diff --git a/examples/agent_examples/whittington_2020_example.ipynb b/examples/agent_examples/whittington_2020_example.ipynb new file mode 100644 index 00000000..7152dfc6 --- /dev/null +++ b/examples/agent_examples/whittington_2020_example.ipynb @@ -0,0 +1,305 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# The Tolman-Eichenbaum Machine: Unifying Space and Relational Memory through Generalization in the Hippocampal Formation.\n", + "James C.R. Whittington, Timothy H. Muller, Shirley Mark, Guifen Chen, Caswell Barry, Neil Burgess, Timothy E.J. Behrens.\n", + "\n", + "[![All Contributors](https://img.shields.io/badge/all_contributors-4-orange.svg?style=flat-square)](#contributors-)\n", + "\n", + "\n", + "https://doi.org/10.1016/j.cell.2020.10.024\n", + "\n", + "For a more detailed explanation of TEM's theory and implementation in this framework, see [Description, Implementation & Analysis of the Tolman-Eichenbaum Machine](https://github.com/LukeHollingsworth/Tollman-Eichenbaum-Implementation/blob/main/Description%2C%20Implementation%20and%20Analysis%20of%20the%20Tollman-Eichenbaum%20Machine.pdf).\n", + "\n", + "#### TEM Virtual Environment\n", + "This implementation uses PyTorch and requires Python 3.8 or later. We suggest setting up a virtual (conda) environment with the following packages:\n", + "```\n", + "Python>=3.8\n", + "PyTorch\n", + "\n", + "matplotlib\n", + "numpy\n", + "tqdm\n", + "```\n", + "\n", + "#### Training TEM\n", + "The TEM model is run from the [run script file](whittington_2020_run.py). The model is initialised with default parameters, identical to those in the original publication; these can be found and changed in the [parameters file](../../neuralplayground/agents/whittington_2020_extras/whittington_2020_parameters.py). 
The model itself can be found [here](../../neuralplayground/agents/whittington_2020_extras/whittington_2020_model.py).\n", + "\n", + "Training should be done on a high-powered computing platform. With 20GB of memory, TEM completes its full 20,000 iterations of training in approx. 48 hours." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Plotting Results\n", + "The notebook below plots TEM model performance and its key neural representations, such as grid cells and place cells. Example results are given; however, the user is able to plot the results of locally saved models as well.\n", + "\n", + "First, import the relevant standard and NeuralPlayground Libraries." + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from importlib import util\n", + "import neuralplayground.agents.whittington_2020_extras.whittington_2020_parameters as parameters\n", + "import neuralplayground.agents.whittington_2020_extras.whittington_2020_analyse as analyse\n", + "import neuralplayground.agents.whittington_2020_extras.whittington_2020_plot as plot" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Select trained model\n", + "Choose the folder which contains the saved data from the trained TEM model. Run 0 is 1/40th of full training, run 1 is 1/20th and run 2 is the full 20,000 iterations of training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "pars_orig = parameters.parameters()\n", + "params = pars_orig.copy()\n", + "\n", + "date = 'example'\n", + "run = '2'\n", + "model_path = os.path.abspath(os.pardir) + '/Summaries/' + date + '/torch_run' + run + '/model/'\n", + "save_path = os.path.abspath(os.pardir) + '/Summaries/' + date + '/torch_run' + run + '/save/'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load variables\n", + "The TEM model saves grid and place representations, as well as information on accuracy and zero-shot inference, at the end of training." + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": {}, + "outputs": [], + "source": [ + "environments = torch.load(save_path + 'environments')\n", + "g = torch.load(save_path + 'g_all')\n", + "p = torch.load(save_path + 'p_all')\n", + "correct_model, correct_node, correct_edge = torch.load(save_path + 'correct_all')\n", + "zero_shot = torch.load(save_path + 'zero_shot')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Choose which environment to plot\n", + "For reasons of efficiency, only the trajectory of one environment (within a batch) is saved and can be loaded." + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [], + "source": [ + "env_to_plot = 0\n", + "environment = environments[env_to_plot]\n", + "shiny_envs = [False for _ in range(16)]\n", + "envs_to_avg = shiny_envs if shiny_envs[env_to_plot] else [not shiny_env for shiny_env in shiny_envs]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Plot TEM prediction accuracy\n", + "The prediction accuracy of TEM is plotted for the final 5000 steps of the training trajectory. This is plotted alongside the proportion of nodes and edges visited." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Zero-shot inference: 26.454293628808866%')" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot results of agent comparison and zero-shot inference analysis\n", + "filt_size = 41\n", + "plt.figure()\n", + "plt.plot(analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_model) if envs_to_avg[env_i]]),0)[1:], filt_size), label='tem')\n", + "plt.plot(analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_node) if envs_to_avg[env_i]]),0)[1:], filt_size), label='node')\n", + "plt.plot(analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_edge) if envs_to_avg[env_i]]),0)[1:], filt_size), label='edge')\n", + "plt.ylim(0, 1)\n", + "plt.legend()\n", + "plt.title('Zero-shot inference: ' + str(np.mean([np.mean(env) for env_i, env in enumerate(zero_shot) if envs_to_avg[env_i]]) * 100) + '%')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Plot all cells\n", + "Plot both hippocampal and entorhinal cells for all frequency modules." + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot.plot_cells(p[env_to_plot], g[env_to_plot], environments[env_to_plot], n_f_ovc=(params['n_f_ovc'] if 'n_f_ovc' in params else 0), columns = 25)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Plot single cell\n", + "Plot hippocampal or entorhinal cells at a specific frequency module." + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAQvklEQVR4nO3by68t6V3f4V/Vuq+9zz433HaDu21jJ5Jpu4kM4iacyAoCgQhCyoB/IFKkRMo4k2TCNJMMcxHTjOIMQgTCgBEgEFhIGDciTmgM7nb6cu77vtaqVZUBys8pEclrd7L1vpGeZ1yDr2q/5/3UaqmbYRiGAICIaEsPAKAeogBAEgUAkigAkEQBgCQKACRRACCJAgBpeuiD3/pn/yii293mlhvZXM/jj3//9Rj6ero2X2zi9R9+I9q2nv8fsLl/Hcf/4ivRzPrSU1KzjTj6r7No6nlNsXuxjDf/7d+LYT8pPSV13STeefuliKEpPSW1bR/37p9GU8+kmEy7eOnVd6Op5yqIpu1jde+iqvcUEXHvF7/4HZ85/DVWFISIiG47rSoIERHTWVdVECIimqNdVUGIiGi6qCoIERHd1byqIERE9Pu2qiBERDRNX91F1076qoIQEdE0Q3Xv6VCVvUoAShIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJAOjkLfN7e548am8y6ati89Y6TbTat7T8PFLIZdZZumEUMzlJ4xMl1to5nsS88YaSd9RGXvaRjaGOqaFP2+jaGuqyCGoanuPR1qeuiDb37tUzGddre55UZmi228/kNfi243Kz0lTe9exv2f+LOYr7elp6ShGaL5s0W0J7vSU1L3fBlPfv2T0Wzq+aF6dbGMb737odidrUpPSYvlJj72qW9G39fznmYnV/HS578e03k9d0F/NYvdXz6I6ayeqL/3/t34pV/6ieiHuj7I/uUvfudnDo5Ct51Ft63nAh6GJhbLXSyW9Vx2s/vncfzqs9IzRoZmiFhExKaew9lfTGL/zp3SM0a2p+s4f3Sv9IyRtu1jvqjn8o2ImN+5iruvPik9Y2R/MY/rR8elZ4z0TcTb7zwsPeMDqecTBIDiRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAdHAUmra/zR03tu8m0fdN6Rkj/fUshq6uTTFEDEPpEWPtsouY1HWepvOuujPeVXjG95tp9JWd8Wa6j2jqOuRHR5
uYTvelZ3wg00Mf/Mznvxrd9uDHb117somjz/9lzJa70lO+bdPG/q2TmK63pZek62dH8Y0vvx5DX8+Pwq5r4uLx3dhfzEtPSau7F/Hj//DL0XWz0lO+7c42pj/6XswrOk+799fx1n/6wZgsutJT0rMXq/jN330tuqt6ztN+GGK+uo7rs3XpKTd28C2/WG1jsarncLb3r2L9seelZ4xdtdFeHEfsJqWXpP5qHhfffFh6xkjXtXF+eqf0jJHJoouje5elZ4zd30Tzt56WXvE37C+Wsb8oveLbLp8cx1tvvlx6xsh+GOK89IgPqJ7PRwCKEwUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUA0uFRaPtbnHFzw2YSw74pPWNsOsTQDKVXjEyX22gm+9IzRtpmiIi63tPuehb9vrJvpM0khrr+dNEe7SKmdY1aLncxaevaVNnNdCPTQx88/rt/HsP24MdvXX81jd2XX4lm2ZWekq5O1/Hm73+6rvfUR2w289herEpPSavjy/js579a1b+c9ngTw2uPYnq8KT3l2zZNxFdOolnW80H25Fv34z/+1uuxe1HPeXrebOPXjr8Zp1fz0lPS9ywi/slHJtG0dX38HOLg26tddRGrei7gZjqLfjuJ2E5KT0n7Z6s4e+tDpWeM9H1E19UTqYiIIZo4vn9ResZIe/cqVi+flZ4xNmsjni0jLuo5493TZfyPb7xUesbI42YT31icRrTb0lPSySTiwXE9f7ebqOz3MgAliQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIA6eAoDPvmNnfcWDPtI5qh9IyR2XwXTbsvPWOkaSIi6npPu+0s+srO07CdVnfGYzLEUNnf7ujkKqazrvSMkeUwiUldrylOu4htX9moA00PffDJr3wm2nk9h+H5s6P4nS99LvZXs9JTUj9EnF4uYnO5LD0lPXz4In72H/xhNDX9Jrx3Hc1PvRWz423pJWn/eBlPv/jZaOZ96Snp2bOj+M0v/Z3Yb+o648fHV3H67E7pKemjx7v4999/Hl1FH2TtZB8P7l7Gw3llHxoHODgK/eU8+sv5bW65kevHd+LdN18uPWNkP0ScdzXdvhHrk2kc378oPWOkeXgVi4+flp4x0ndt7M9WpWeMXD86ibe//tHSM0b6IWIT/+sXaB1mkyG+9149HxgREZNJF3cq23Soum4wAIoSBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQDSwVHourr6sVxuYzLZl54x0kRExFB4xdjV5aK6v91wPYmha0rPGGlXXcSkLz1jpMYz/tfqOuPX22ns93Wdp35oY6jrNR1seuiDX/wPX4jlcnubW25kNuvikx9/HNeXy9JT0uroMj71/f89mraey2Wy2kb7ytM4untVekrani7jvX/1oxGrei68y/Nl/Okf/u3Yn9ZzniIiXv3QeVxWdMaXy028+vF3om3rufFmy208+MS7cbSq5366eLGON/7g+6Kpq1XxiQOeOTgKF+eruDhf/V/M+X9rMe/iZHEay0U9F8t6vYuPvPys9IyRydF13Hv5RekZI+1mGrunR6VnjFyfruPpX71Uesb/0XxazwW8nO/ju+5flp4xMl9dx4OHZ6VnjGyu53F1uS494wOp678rAFCUKACQRAGAJAoAJFEAII
kCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQDo5C0wy3uePG9vsm+romRbedRr9vSs8YGbpJDH1dm9pFF9H2pWeMzGZdtJVt+mt1HfKum0Zf2Xnqu0l1m2aLXbTtvvSMD2R66IOvfPg09n09PyxWR5fxfZ/7bzGf1fPiJ4tdrF55Gov1pvSUtD1bxl/85x+ItqL39Ox0GV/6vU/H5mJRekpaz/bxibtnsd3OS09Jdx6cxud+9g9iMqknVsO+jeFyHovFtvSU9OjR3fg3/+4no63oPZ1HF1/dbeLiqp4zHhHxCwc8c3AUptMhplHPxbJe7eLBw7PSM0amq00cPTwvPWOk28xie7YuPWPk/Nk63n77u0rPGDmed/Gx4z6mk3q+zBfLXXz41UelZ4z0XRu750elZ4xMTtfx3qN7pWeMnMYu3pw9Lz3jA6nn0x+A4kQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgCQKACRRACCJAgBJFABIogBAEgUAkigAkEQBgHRwFNrJ/jZ33NhuO4t+35SeMdJ3kxj6ujZN5l00bV96xshy0cWksk27fRP9UHrF2OZqEfuuru+2phkioq4XtT7axHRa1/00izbaul7TwaaHPviT//iXY3O1uM0tN7NvYhJDLJbb0kvS6ZOT+OV//fMRk3pOw7Zr4v2n69heLktPSavFLv7+a2/F1XZeekpan1zEZ378jVgtutJT/jdDbM6XsTq+Kj0knT29E3/8Kz8U+/3BV8ete75t4/HRWZydrUtPSfOhiS90D2Lb1PXxc4iD/7Lru5exvnt5m1tuZNg30V/Xc6lERMRkiKfvPSy9YmSza+K956vSM0a2q21878unsZjuSk9JxyfX8fJHn5SeMdb0MVl1EVHPr8/9bhrP36/rjD/btPH4fBXR1PMx1kTEOiaxjknpKTdW129TAIoSBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQCSKACQRAGAJAoAJFEAIIkCAEkUAEiiAEASBQDS/79RaIaIGEqvGFmsNtFO9qVnjEzaiKay99R1bfR96RVju+tZ7Pe1/XNoYqjrTxfzCs/4sh1i0tT1ovYxRF/Zv7tDTQ9+cr6r6g6+eHISX/svPxx9Nyk9JZ1uJvHGVR8X58vSU9K95S5+7LW3Yxia0lPS8s5lfPrH3ojlois9JQ1DxOZsGYujbekp6eL5Ufzp770Ww1BPrF5sJ/Ebj6dxeb4uPSU9XG/in/7IX8SmojO+eHAWn/iZP4nlrK6ARvzCd3zi4Cg0TUTU886j30/ixXsPSs8YeXY1jfeeH5eeMbLoI47Xu9IzRo5OruJD3/2s9IyRoY/ou1lERRfLvpvGi/cflp4x8vR6Em8/uVN6xsh83sTDO5vSM0ZWDy/ik699q/SMD6SeTxAAihMFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFANLBUej3zW3uuLH5ehPtZF96xshi2kfbDKVnjFx3k+j60ivGdttZ9PvKvkeaiCHq+tvNFtto28rO+KSPSWVn/HI3ia6y+6m7mkffVXbGDzQ99ME//9XPxXSxu80tN/L8bBlfe+c4rs6PSk9J63kXP/OpJ7Gp6DAc3z+NT//0V+J4Vc/frruexZM3PxzzijY9f3
YUv/vbn439flJ6Spq0Q8yHNq4vV6WnpOM7F/HPf+6PYtvXcwlP717GR7/w9bh7tC09JQ1dE7u37sT8eFN6yo0dHIXuahHd1eI2t9zIxfN1PHp8r/SMv+Fo1sfRrJ5P85M7m/jIK09KzxjZXc5j+/5J7Dez0lPS5nIZ7779UukZI5O2j7vrrvSMkX5o4sMPLkrPGJk9OI+PfPJR6Rkjw66NeLyK6Or5yDhUPZ+0ABQnCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIASRQASKIAQBIFAJIoAJBEAYAkCgCkZhiGofQIAOrglwIASRQASKIAQBIFAJIoAJBEAYAkCgAkUQAgiQIA6X8C74jHPIU0N6EAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cells = g[env_to_plot]\n", + "# Get specific row and column\n", + "n_cell = 0\n", + "row = 0\n", + "col = 0\n", + "# Plot rate map for this cell by collection firing rate at each location\n", + "loc_rates = cells[n_cell]\n", + "plot.plot_map(environment, np.array([loc_rates[l][col] for l in range(len(loc_rates))]), shape='square')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Example Trajectory\n", + "Below are examples of the discrete random walks taken by TEM in a simple square environment. The left plot shows the trajectory of a single agent and the right shows the combined trajectories of 4 agents in a batch.\n", + "\n", + "
\n", + " \n", + " \n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. References\n", + "[1] J. C. Whittington, T. H. Muller, S. Mark, G. Chen, C. Barry, N. Burgess, and T. E. Behrens, “The Tolman-Eichenbaum machine: Unifying space and relational memory through generalization in the hippocampal formation,” Cell, vol. 183, pp. 1249–1263.e23, Nov. 2020.\n", + "\n", + "[2] J. Krupic, N. Burgess, and J. O’Keefe, “Neural representations of location composed of spatially periodic bands,” Science, vol. 337, pp. 853–857, Aug. 2012.\n", + "\n", + "[3] E. C. Tolman, “Cognitive maps in rats and men,” Psychological Review, vol. 55, no. 4, pp. 189–208, 1948.\n", + "\n", + "[4] S. S. Deshmukh and J. J. Knierim, “Representation of non-spatial and spatial information in the lateral entorhinal cortex,” Frontiers in Behavioral Neuroscience, vol. 5, 2011.\n", + "\n", + "[5] F. Savelli, D. Yoganarasimha, and J. J. Knierim, “Influence of boundary removal on the spatial representations\n", + "of the medial entorhinal cortex,” Hippocampus, vol. 18, pp. 1270–1282, Dec. 2008.\n", + "\n", + "[6] M. L. Shapiro, H. Tanila, and H. Eichenbaum, “Cues that hippocampal place cells encode: Dynamic and hierarchical representation of local and distal stimuli,” Hippocampus, vol. 7, no. 6, pp. 624–642, 1997." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.9 ('TorchTEM')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "aae7794e8b9129601af2b947d7e70a79cdb1d800d39e72db0e99f5fc7aff18fc" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/agent_examples/whittington_2020_plot.py b/examples/agent_examples/whittington_2020_plot.py new file mode 100644 index 00000000..f80251bb --- /dev/null +++ b/examples/agent_examples/whittington_2020_plot.py @@ -0,0 +1,186 @@ +# Standard Imports +import importlib.util +import pickle + +import matplotlib.pyplot as plt +import numpy as np +import torch + +import neuralplayground.agents.whittington_2020_extras.whittington_2020_analyse as analyse +from neuralplayground.agents.whittington_2020 import Whittington2020 +from neuralplayground.arenas.batch_environment import BatchEnvironment + +# NeuralPlayground Imports +from neuralplayground.arenas.discritized_objects import DiscreteObjectEnvironment + +# NeuralPlayground Experiment Imports +from neuralplayground.experiments import Sargolini2006Data + +# Select trained model +date = "2023-05-17" +run = "0" +index = "19999" +base_path = "/nfs/nhome/live/lhollingsworth/Documents/NeuralPlayground/NPG/EHC_model_comparison" +npg_path = "/nfs/nhome/live/lhollingsworth/Documents/NeuralPlayground/NPG/EHC_model_comparison/examples" +base_win_path = "H:/Documents/PhD/NeuralPlayground" +win_path = "H:/Documents/PhD/NeuralPlayground/NPG/NeuralPlayground/examples" +# Load the model: use import library to import module from specified path +model_spec = importlib.util.spec_from_file_location( + "model", win_path + "/Summaries/" + date + 
"/torch_run" + run + "/script/whittington_2020_model.py" +) +model = importlib.util.module_from_spec(model_spec) +model_spec.loader.exec_module(model) + +# Load the parameters of the model +params = torch.load(win_path + "/Summaries/" + date + "/torch_run" + run + "/model/params_" + index + ".pt") +# Create a new tem model with the loaded parameters +tem = model.Model(params) +# Load the model weights after training +model_weights = torch.load(win_path + "/Summaries/" + date + "/torch_run" + run + "/model/tem_" + index + ".pt") +# Set the model weights to the loaded trained model weights +tem.load_state_dict(model_weights) +# Make sure model is in evaluate mode (not crucial because it doesn't currently use dropout or batchnorm layers) +tem.eval() + +# Initialise environment parameters +batch_size = 16 +arena_x_limits = [ + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], +] +arena_y_limits = [ + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], +] +# arena_x_limits = [[-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10], +# [-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10]] +# arena_y_limits = [[-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1], +# [-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1]] +env_name = "env_example" +mod_name = "SimpleTEM" +time_step_size = 1 +state_density = 1 +agent_step_size = 1 / state_density +n_objects = 45 + +# Init simple 2D environment with discrtised objects +env_class = DiscreteObjectEnvironment +env = BatchEnvironment( + environment_name=env_name, + env_class=DiscreteObjectEnvironment, + batch_size=batch_size, + arena_x_limits=arena_x_limits, + arena_y_limits=arena_y_limits, + state_density=state_density, 
+ n_objects=n_objects, + agent_step_size=agent_step_size, + use_behavioural_data=False, + data_path=None, + experiment_class=Sargolini2006Data, +) + +# Init TEM agent +agent = Whittington2020( + model_name=mod_name, + params=params, + batch_size=batch_size, + room_widths=env.room_widths, + room_depths=env.room_depths, + state_densities=env.state_densities, + use_behavioural_data=False, +) + +# # Run around environment +# observation, state = env.reset(random_state=True, custom_state=None) +# while agent.n_walk < 5000: +# if agent.n_walk % 100 == 0: +# print(agent.n_walk) +# action = agent.batch_act(observation) +# observation, state = env.step(action, normalize_step=True) +# model_input, history, environments = agent.collect_final_trajectory() +# environments = [env.collect_environment_info(model_input, history, environments)] + +# # Save environments and model_input using pickle +# with open('NPG_environments.pkl', 'wb') as f: +# pickle.dump(environments, f) +# with open('NPG_model_input.pkl', 'wb') as f: +# pickle.dump(model_input, f) + +# Load environments and model_input using pickle +with open("NPG_environments.pkl", "rb") as f: + environments = pickle.load(f) +with open("NPG_model_input.pkl", "rb") as f: + model_input = pickle.load(f) + +with torch.no_grad(): + forward = tem(model_input, prev_iter=None) +include_stay_still = False +shiny_envs = [False, False, False, False] +env_to_plot = 0 +envs_to_avg = shiny_envs if shiny_envs[env_to_plot] else [not shiny_env for shiny_env in shiny_envs] + +correct_model, correct_node, correct_edge = analyse.compare_to_agents( + forward, tem, environments, include_stay_still=include_stay_still +) +zero_shot = analyse.zero_shot(forward, tem, environments, include_stay_still=include_stay_still) +occupation = analyse.location_occupation(forward, tem, environments) +g, p = analyse.rate_map(forward, tem, environments) +from_acc, to_acc = analyse.location_accuracy(forward, tem, environments) + +# Plot rate maps for grid or place 
cells +agent.plot_rate_map(g) + +# Plot results of agent comparison and zero-shot inference analysis +filt_size = 41 +plt.figure() +plt.plot( + analyse.smooth( + np.mean(np.array([env for env_i, env in enumerate(correct_model) if envs_to_avg[env_i]]), 0)[1:], filt_size + ), + label="tem", +) +plt.plot( + analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_node) if envs_to_avg[env_i]]), 0)[1:], filt_size), + label="node", +) +plt.plot( + analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_edge) if envs_to_avg[env_i]]), 0)[1:], filt_size), + label="edge", +) +plt.ylim(0, 1) +plt.legend() +plt.title( + "Zero-shot inference: " + + str(np.mean([np.mean(env) for env_i, env in enumerate(zero_shot) if envs_to_avg[env_i]]) * 100) + + "%" +) + +# plt.show() diff --git a/examples/agent_examples/whittington_2020_run.py b/examples/agent_examples/whittington_2020_run.py new file mode 100644 index 00000000..a0a30611 --- /dev/null +++ b/examples/agent_examples/whittington_2020_run.py @@ -0,0 +1,124 @@ +""" +Run file for the Tolman-Eichenbaum Machine (TEM) model from Whittington et al. 2020. An example setup is provided, with +TEM learning to predict upcoming sensory stimulus in a range of 16 square environments of varying sizes. 
+""" + +# Standard Imports + +import matplotlib.pyplot as plt + +# NeuralPlayground Agent Imports +import neuralplayground.agents.whittington_2020_extras.whittington_2020_parameters as parameters +from neuralplayground.agents.whittington_2020 import Whittington2020 +from neuralplayground.arenas.batch_environment import BatchEnvironment + +# NeuralPlayground Arena Imports +from neuralplayground.arenas.discritized_objects import DiscreteObjectEnvironment + +# NeuralPlayground Experiment Imports +from neuralplayground.experiments import Sargolini2006Data + +# Initialise TEM Parameters +pars_orig = parameters.parameters() +params = pars_orig.copy() + +# Initialise environment parameters +batch_size = 16 +arena_x_limits = [ + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], +] +arena_y_limits = [ + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], + [-4, 4], + [-5, 5], + [-6, 6], + [-5, 5], +] +# arena_x_limits = [[-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10], +# [-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10]] +# arena_y_limits = [[-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1], +# [-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1]] +env_name = "Sargolini2006" +mod_name = "SimpleTEM" +time_step_size = 1 +state_density = 1 +agent_step_size = 1 / state_density +n_objects = 45 + +# # Init environment from Hafting 2008 (optional, if chosen, comment out the ) +# env = Hafting2008(agent_step_size=agent_step_size, +# time_step_size=time_step_size, +# use_behavioral_data=False) + +# # Init simple 2D (batched) environment with discrtised objects +# env_class = DiscreteObjectEnvironment + +# Init environment from Sargolini, with behavioural data instead of random walk +env = 
BatchEnvironment( + environment_name=env_name, + env_class=DiscreteObjectEnvironment, + batch_size=batch_size, + arena_x_limits=arena_x_limits, + arena_y_limits=arena_y_limits, + state_density=state_density, + n_objects=n_objects, + agent_step_size=agent_step_size, + use_behavioural_data=False, + data_path=None, + experiment_class=Sargolini2006Data, +) + +# Init TEM agent +agent = Whittington2020( + model_name=mod_name, + params=params, + batch_size=batch_size, + room_widths=env.room_widths, + room_depths=env.room_depths, + state_densities=env.state_densities, + use_behavioural_data=False, +) + +# Reset environment and begin training (random_state=True is currently necessary) +observation, state = env.reset(random_state=True, custom_state=None) +for i in range(3): + print("Iteration: ", i) + while agent.n_walk < params["n_rollout"]: + actions = agent.batch_act(observation) + observation, state = env.step(actions, normalize_step=True) + agent.update() + +# Plot most recent trajectory of the first environment in batch +ax = env.plot_trajectory() +fontsize = 18 +ax.grid() +ax.set_xlabel("width", fontsize=fontsize) +ax.set_ylabel("depth", fontsize=fontsize) +plt.savefig("trajectory.png") +plt.show() diff --git a/neuralplayground/agents/whittington_2020.py b/neuralplayground/agents/whittington_2020.py new file mode 100644 index 00000000..a9faf84d --- /dev/null +++ b/neuralplayground/agents/whittington_2020.py @@ -0,0 +1,525 @@ +import copy +import os +import shutil +import sys +import time + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +import neuralplayground.agents.whittington_2020_extras.whittington_2020_model as model +import neuralplayground.agents.whittington_2020_extras.whittington_2020_parameters as parameters + +# Custom modules +import neuralplayground.agents.whittington_2020_extras.whittington_2020_utils as utils +from neuralplayground.plotting.plot_utils import make_plot_rate_map + 
+from .agent_core import AgentCore + +sys.path.append("../") + + +class Whittington2020(AgentCore): + """ + Implementation of TEM 2020 by James C.R. Whittington, Timothy H. Muller, Shirley Mark, Guifen Chen, Caswell Barry, + Neil Burgess, Timothy E.J. Behrens. The Tolman-Eichenbaum Machine: Unifying Space and Relational Memory through + Generalization in the Hippocampal Formation https://doi.org/10.1016/j.cell.2020.10.024. + ---- + Attributes + --------- + mod_kwargs : dict + Model parameters + params: dict + contains the majority of parameters used by the model and environment + room_width: float + room width specified by the environment (see examples/examples/whittington_2020_example.ipynb) + room_depth: float + room depth specified by the environment (see examples/examples/whittington_2020_example.ipynb) + state_density: float + density of agent states (should be proportional to the step-size) + tem: class + TEM model + + Methods + --------- + reset(self): + initialise model and associated variables for training + def initialise(self): + generate random distribution of objects and intialise optimiser, logger and relevant variables + act(self, positions, policy_func): + generates batch of random actions to be passed into the environment. 
If the returned positions are allowed, + they are saved along with corresponding actions + update(self): + Perform forward pass of model and calculate losses and accuracies + action_policy(self): + random action policy that picks from [up, down, left right] + discretise(self, step): + convert (x,y) position into discrete location + walk(self, positions): + convert continuous positions into sequence of discrete locations + make_observations(self, locations): + observe what randomly distributed object is located at each position of a walk + step_to_actions(self, actions): + convert (x,y) action information into an integer value + """ + + def __init__(self, model_name: str = "TEM", **mod_kwargs): + """ + Parameters + ---------- + model_name : str + Name of the specific instantiation of the ExcInhPlasticity class + mod_kwargs : dict + params: dict + contains the majority of parameters used by the model and environment + room_width: float + room width specified by the environment (see examples/examples/whittington_2020_example.ipynb) + room_depth: float + room depth specified by the environment (see examples/examples/whittington_2020_example.ipynb) + state_density: float + density of agent states (should be proportional to the step-size) + """ + super().__init__() + params = mod_kwargs["params"] + self.room_widths = mod_kwargs["room_widths"] + self.room_depths = mod_kwargs["room_depths"] + self.state_densities = mod_kwargs["state_densities"] + self.pars = copy.deepcopy(params) + self.tem = model.Model(self.pars) + self.batch_size = mod_kwargs["batch_size"] + self.use_behavioural_data = mod_kwargs["use_behavioural_data"] + self.n_envs_save = 4 + self.n_states = [ + int(self.room_widths[i] * self.room_depths[i] * self.state_densities[i]) for i in range(self.batch_size) + ] + self.poss_actions = [[0, 0], [0, -1], [1, 0], [0, 1], [-1, 0]] + self.n_actions = len(self.poss_actions) + self.final_model_input = None + + self.prev_observations = None + self.reset() + + def 
reset(self): + """ + initialise model and associated variables for training, set n_walk=-1 initially to account for the lack of + actions at initialisation + """ + self.tem = model.Model(self.pars) + self.initialise() + self.n_walk = -1 + self.final_model_input = None + self.obs_history = [] + self.walk_actions = [] + self.walk_action_values = [] + self.prev_action = None + self.prev_observation = None + self.prev_actions = [[None, None] for _ in range(self.batch_size)] + self.prev_observations = [[-1, -1, [float("inf"), float("inf")]] for _ in range(self.batch_size)] + + def act(self, observation, policy_func=None): + """ + The base model executes one of four action (up-down-right-left) with equal probability. + This is used to move on the rectangular environment states space (transmat). + This is done for a single environment. + Parameters + ---------- + positions: array (16,2) + Observation from the environment class needed to choose the right action (Here the position). + Returns + ------- + action : array (16,2) + Action values (Direction of the agent step) in this case executes one of four action + (up-down-right-left) with equal probability. + """ + new_action = self.action_policy() + if observation[0] == self.prev_observation[0]: + self.prev_action = new_action + else: + self.walk_actions.append(self.prev_action) + self.obs_history.append(self.prev_observation) + self.prev_action = new_action + self.prev_observation = observation + self.n_walk += 1 + + return new_action + + def batch_act(self, observations, policy_func=None): + """ + The base model executes one of four action (up-down-right-left) with equal probability. + This is used to move on the rectangular environment states space (transmat). + This is done for a batch of 16 environments. + Parameters + ---------- + observations: array (16,3/4) + Observation from the environment class needed to choose the right action + (here the state ID and position). 
If behavioural data is used, the observation includes head direction. + Returns + ------- + new_actions : array (16,2) + Action values (direction of the agent step) in this case executes one of four action + (up-down-right-left) with equal probability. + """ + + if self.use_behavioural_data: + state_diffs = [observations[i][0] - self.prev_observations[i][0] for i in range(self.batch_size)] + new_actions = self.infer_action(state_diffs) + self.walk_actions.append(new_actions) + self.obs_history.append(self.prev_observations.copy()) + self.prev_observations = observations + self.n_walk += 1 + + elif not self.use_behavioural_data: + locations = [env[0] for env in observations] + all_allowed = True + new_actions = [] + for i, loc in enumerate(locations): + if loc == self.prev_observations[i][0] and self.prev_actions[i] != [0, 0]: + all_allowed = False + break + + if all_allowed: + self.walk_actions.append(self.prev_actions.copy()) + self.obs_history.append(self.prev_observations.copy()) + for batch in range(self.pars["batch_size"]): + new_actions.append(self.action_policy()) + self.prev_actions = new_actions + self.prev_observations = observations + self.n_walk += 1 + + elif not all_allowed: + for i, loc in enumerate(locations): + if loc == self.prev_observations[i][0]: + new_actions.append(self.action_policy()) + else: + new_actions.append(self.prev_actions[i]) + self.prev_actions = new_actions + + return new_actions + + def update(self): + """ + Compute forward pass through model, updating weights, calculating TEM variables and collecting + losses / accuracies + """ + self.iter = int((len(self.obs_history) / 20) - 1) + self.global_steps += 1 + history = self.obs_history[-self.pars["n_rollout"] :] + locations = [[{"id": env_step[0], "shiny": None} for env_step in step] for step in history] + observations = [[env_step[1] for env_step in step] for step in history] + actions = self.walk_actions[-self.pars["n_rollout"] :] + self.n_walk = 0 + # Convert action vectors to 
action values + action_values = self.step_to_actions(actions) + self.walk_action_values.append(action_values) + # Get start time for function timing + start_time = time.time() + # Get updated parameters for this backprop iteration + ( + self.eta_new, + self.lambda_new, + self.p2g_scale_offset, + self.lr, + self.walk_length_center, + loss_weights, + ) = parameters.parameter_iteration(self.iter, self.pars) + # Update eta and lambda + self.tem.hyper["eta"] = self.eta_new + self.tem.hyper["lambda"] = self.lambda_new + # Update scaling of offset for variance of inferred grounded position + self.tem.hyper["p2g_scale_offset"] = self.p2g_scale_offset + # Update learning rate (the neater torch-way of doing this would be a scheduler, but this is quick and easy) + for param_group in self.adam.param_groups: + param_group["lr"] = self.lr + + # Collect all information in walk variable + model_input = [ + [ + locations[i], + torch.from_numpy(np.reshape(observations, (20, 16, 45))[i]).type(torch.float32), + np.reshape(action_values, (20, 16))[i].tolist(), + ] + for i in range(self.pars["n_rollout"]) + ] + self.final_model_input = model_input + + forward = self.tem(model_input, self.prev_iter) + + # Accumulate loss from forward pass + loss = torch.tensor(0.0) + # Make vector for plotting losses + plot_loss = 0 + # Collect all losses / variables + for ind, step in enumerate(forward): + # Make list of losses included in this step + step_loss = [] + # Only include loss for locations that have been visited before + for env_i, env_visited in enumerate(self.visited): + if env_visited[step.g[env_i]["id"]]: + step_loss.append(loss_weights * torch.stack([i[env_i] for i in step.L])) + else: + env_visited[step.g[env_i]["id"]] = True + step_loss = torch.tensor(0) if not step_loss else torch.mean(torch.stack(step_loss, dim=0), dim=0) + # Save all separate components of loss for monitoring + plot_loss = plot_loss + step_loss.detach().numpy() + # And sum all components, then add them to total 
loss of this step + loss = loss + torch.sum(step_loss) + + # Reset gradients + self.adam.zero_grad() + # Do backward pass to calculate gradients with respect to total loss of this chunk + loss.backward(retain_graph=True) + # Then do optimiser step to update parameters of model + self.adam.step() + # Update the previous iteration for the next chunk with the final step of this chunk, removing all operation history + self.prev_iter = [forward[-1].detach()] + + # Compute model accuracies + acc_p, acc_g, acc_gt = np.mean([[np.mean(a) for a in step.correct()] for step in forward], axis=0) + acc_p, acc_g, acc_gt = [a * 100 for a in (acc_p, acc_g, acc_gt)] + # Log progress + if self.iter % 10 == 0: + # Write series of messages to logger from this backprop iteration + self.logger.info("Finished backprop iter {:d} in {:.2f} seconds.".format(self.iter, time.time() - start_time)) + self.logger.info( + "Loss: {:.2f}. {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} \ + {:.2f} {:.2f}".format( + loss.detach().numpy(), *plot_loss + ) + ) + self.logger.info("Accuracy:

{:.2f}% {:.2f}% {:.2f}%".format(acc_p, acc_g, acc_gt)) + self.logger.info( + "Parameters: {:.2f} {:.2f} {:.2f} {:.2f}".format( + np.max(np.abs(self.prev_iter[0].M[0].numpy())), + self.tem.hyper["eta"], + self.tem.hyper["lambda"], + self.tem.hyper["p2g_scale_offset"], + ) + ) + self.logger.info("Weights:" + str([w for w in loss_weights.numpy()])) + self.logger.info(" ") + # Also store the internal state (all learnable parameters) and the hyperparameters periodically + if self.iter % self.pars["save_interval"] == 0: + torch.save(self.tem.state_dict(), self.model_path + "/tem_" + str(self.iter) + ".pt") + torch.save(self.tem.hyper, self.model_path + "/params_" + str(self.iter) + ".pt") + + # Save the final state of the model after training has finished + if self.iter == self.pars["train_it"] - 1: + torch.save(self.tem.state_dict(), self.model_path + "/tem_" + str(self.iter) + ".pt") + torch.save(self.tem.hyper, self.model_path + "/params_" + str(self.iter) + ".pt") + + def initialise(self): + """ + Generate random distribution of objects and intialise optimiser, logger and relevant variables + """ + # Create directories for storing all information about the current run + ( + self.run_path, + self.train_path, + self.model_path, + self.save_path, + self.script_path, + self.envs_path, + ) = utils.make_directories() + # Save all python files in current directory to script directory + self.save_files() + # Save parameters + np.save(os.path.join(self.save_path, "params"), self.pars) + # Create a tensor board to stay updated on training progress. 
def save_files(self):
    """
    Copy the TEM helper modules into the run's script directory.

    The model, parameters, analyse, plot and utils modules are copied, in that order, from
    ``<parent of this file's directory>/agents/whittington_2020_extras/`` to
    ``self.script_path`` using ``shutil.copy2``.
    """
    module_names = (
        "whittington_2020_model.py",
        "whittington_2020_parameters.py",
        "whittington_2020_analyse.py",
        "whittington_2020_plot.py",
        "whittington_2020_utils.py",
    )
    this_dir = os.path.dirname(os.path.abspath(__file__))
    for module_name in module_names:
        # Resolve the directory one level above this file for every copy
        parent_dir = os.path.abspath(os.path.join(this_dir, os.pardir))
        source_root = os.path.abspath(os.path.join(os.getcwd(), parent_dir))
        shutil.copy2(
            source_root + "/agents/whittington_2020_extras/" + module_name,
            os.path.join(self.script_path, module_name),
        )
def infer_action(self, state_diffs):
    """
    Infer the (x, y) step taken in each environment from the change in state index.

    Parameters
    ----------
    state_diffs: sequence of int
        One entry per environment (``self.batch_size`` entries are read): current state
        index minus previous state index.

    Returns
    -------
    actions: list of [int, int]
        One step per environment, chosen by comparing the difference with that
        environment's ``self.room_widths[i]``:
        ``-width -> [0, 1]``, ``+width -> [0, -1]``, ``-1 -> [-1, 0]``, ``+1 -> [1, 0]``,
        anything else -> ``[0, 0]``.
    """
    actions = []
    for i in range(self.batch_size):
        diff = state_diffs[i]
        width = self.room_widths[i]
        if diff == -width:
            actions.append([0, 1])
        elif diff == width:
            actions.append([0, -1])
        elif diff == -1:
            actions.append([-1, 0])
        # Compare this environment's difference (not the whole list) so a +1 step is detected
        elif diff == 1:
            actions.append([1, 0])
        else:
            actions.append([0, 0])

    return actions
def plot_rate_map(self, rate_maps):
    """
    Plot one figure of per-cell rate maps for each of the five frequency modules.

    Parameters
    ----------
    rate_maps: sequence
        Only ``rate_maps[0]`` is used. For each module ``i`` in 0..4, ``rate_maps[0][i]`` must be
        a 2D array with one column per cell; each column is reshaped to
        ``(self.room_widths[0], self.room_depths[0])`` before plotting.

    Returns
    -------
    figs: list
        One matplotlib figure per frequency module.
    axes: list
        For each figure, the 2D array (rows x 6) of axes holding the individual rate maps.
    """
    frequencies = ["Theta", "Delta", "Beta", "Gamma", "High Gamma"]
    num_cols = 6  # Number of subplots per row
    figs = []
    axes = []

    for i, frequency in enumerate(frequencies):
        module_maps = rate_maps[0][i]
        n_cells = module_maps.shape[1]
        num_rows = np.ceil(n_cells / num_cols).astype(int)

        # squeeze=False keeps the axes array 2D even when all cells fit in a single row,
        # so the [row, col] indexing below is always valid
        fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(15, 10), squeeze=False)
        fig.suptitle(f"{frequency} Rate Maps", fontsize=16)

        # Reserve an axes on the right-hand side for a single colorbar shared by the figure
        cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.7])

        # One row per cell after transposing
        cell_maps = np.asarray(module_maps).T
        for j in range(n_cells):
            # Reshape the flat per-location rates of this cell into a matrix
            rate_map_mat = np.reshape(cell_maps[j], (self.room_widths[0], self.room_depths[0]))
            title = f"Cell {j+1}"
            make_plot_rate_map(rate_map_mat, axs[j // num_cols, j % num_cols], title, "", "", "")

        # Hide the subplots left over in the last row
        for j in range(n_cells, num_rows * num_cols):
            axs[j // num_cols, j % num_cols].axis("off")

        # The colorbar is taken from the first image on the first cell's axes
        # (assumes make_plot_rate_map draws an image there -- TODO confirm)
        cbar = fig.colorbar(axs[0, 0].get_images()[0], cax=cbar_ax)
        cbar.set_label("Firing rate", fontsize=14)

        figs.append(fig)
        axes.append(axs)

    return figs, axes
performance. + Parameters + ---------- + forward : list + List of forward passes through the model, each containing the model input, the model output, and the model state. + model : TEM + The model that was used to generate the forward passes. + environments : list + List of environments that were used to generate the forward passes. + Returns + ------- + all_correct : list + List of lists of booleans, indicating for each step whether the model predicted the observation correctly. + all_location_frac : list + List of lists of floats, indicating for each step the fraction of locations visited. + all_action_frac : list + List of lists of floats, indicating for each step the fraction of actions taken. + """ + # Keep track of whether model prediction were correct, as well as the fraction of nodes/edges visited, across environments + all_correct, all_location_frac, all_action_frac = [], [], [] + # Run through environments and monitor performance in each + for env_i, env in enumerate(environments): + # Keep track for each location whether it has been visited + location_visited = np.full(env.n_locations, False) + # And for each action in each location whether it has been taken + action_taken = np.full((env.n_locations, model.hyper["n_actions"]), False) + # Not all actions are available at every location (e.g. edges of grid world). 
Find how many actions can be taken + action_available = np.full((env.n_locations, model.hyper["n_actions"]), False) + for currLocation in env.locations: + for currAction in currLocation["actions"]: + if np.sum(currAction["transition"]) > 0: + if model.hyper["has_static_action"]: + if currAction["id"] > 0: + action_available[currLocation["id"], currAction["id"] - 1] = True + else: + action_available[currLocation["id"], currAction["id"]] = True + # Make array to list whether the observation was predicted correctly or not + correct = [] + # Make array that stores for each step the fraction of locations visited + location_frac = [] + # And an array that stores for each step the fraction of actions taken + action_frac = [] + # Run through iterations of forward pass to check when an action is taken for the first time + for step in forward: + # Update the states that have now been visited + location_visited[step.g[env_i]["id"]] = True + # ... And the actions that now have been taken + if model.hyper["has_static_action"]: + if step.a[env_i] > 0: + action_taken[step.g[env_i]["id"], step.a[env_i] - 1] = True + else: + action_taken[step.g[env_i]["id"], step.a[env_i]] = True + # Mark the location of the previous iteration as visited + correct.append((torch.argmax(step.x_gen[2][env_i]) == torch.argmax(step.x[env_i])).numpy()) + # Add the fraction of locations visited for this step + location_frac.append(np.sum(location_visited) / location_visited.size) + # ... 
def location_accuracy(forward, model, environments):
    """
    Compute the fraction of correct sensory predictions per location, split by departure and arrival.

    For every transition (each step after the first), the prediction ``step.x_gen[2]`` is compared
    with the observation ``step.x`` by their argmax.
    Parameters
    ----------
    forward : list
        Steps of a forward pass; each step exposes ``g``, ``x`` and ``x_gen`` indexed by environment.
    model : TEM
        Not used by this function.
    environments : list
        One entry per environment; ``env[2]`` is read as the number of locations.
    Returns
    -------
    accuracy_from : list
        Per environment, per location: fraction of correct predictions on transitions that
        departed from that location (0 when there were none).
    accuracy_to : list
        Per environment, per location: fraction of correct predictions on transitions that
        arrived at that location (0 when there were none).
    """

    def _fraction_correct(outcomes):
        # Locations without any recorded transition get accuracy 0 instead of a division by zero
        return sum(outcomes) / (len(outcomes) if len(outcomes) > 0 else 1)

    accuracy_from, accuracy_to = [], []
    for env_i, env in enumerate(environments):
        n_locations = env[2]
        hits_on_departure = [[] for _ in range(n_locations)]
        hits_on_arrival = [[] for _ in range(n_locations)]
        # Pair every step with the step that preceded it
        for previous, current in zip(forward, forward[1:]):
            hit = (torch.argmax(current.x_gen[2][env_i]) == torch.argmax(current.x[env_i])).numpy().tolist()
            hits_on_arrival[current.g[env_i]["id"]].append(hit)
            hits_on_departure[previous.g[env_i]["id"]].append(hit)
        accuracy_from.append([_fraction_correct(hits) for hits in hits_on_departure])
        accuracy_to.append([_fraction_correct(hits) for hits in hits_on_arrival])
    return accuracy_from, accuracy_to
def zero_shot(forward, model, environments, include_stay_still=True):
    """
    Collect prediction outcomes for transitions that offer zero-shot inference.

    A transition counts when the location arrived at has been visited before, but the action
    that led there has not yet been taken from the departure location.
    Parameters
    ----------
    forward : list
        Steps of a forward pass; each step exposes ``a``, ``g``, ``x`` and ``x_gen`` indexed by environment.
    model : TEM
        Model whose ``hyper["n_actions"]`` and ``hyper["has_static_action"]`` are read.
    environments : list
        One entry per environment; ``env[2]`` is read as the number of locations.
    include_stay_still : bool
        When False and the model has a static action, transitions whose action is 0 are ignored.
    Returns
    -------
    all_correct_zero_shot : list
        Per environment, a list with one boolean (0-d numpy array) per zero-shot opportunity,
        telling whether the argmax of ``x_gen[2]`` matched the argmax of ``x``.
    """
    # The static action, if present, takes an extra slot in the action table
    n_actions = model.hyper["n_actions"] + model.hyper["has_static_action"]
    all_correct_zero_shot = []
    for env_i, env in enumerate(environments):
        n_locations = env[2]
        seen_location = np.full(n_locations, False)
        seen_action = np.full((n_locations, n_actions), False)
        outcomes = []
        previous = forward[0]
        for current in forward[1:]:
            origin = previous.g[env_i]["id"]
            action = previous.a[env_i]
            # Optionally disregard standing still as a transition
            if model.hyper["has_static_action"] and action == 0 and not include_stay_still:
                action = None
            seen_location[origin] = True
            arrived_before = seen_location[current.g[env_i]["id"]]
            if arrived_before and action is not None and not seen_action[origin, action]:
                outcomes.append(
                    (torch.argmax(current.x_gen[2][env_i]) == torch.argmax(current.x[env_i])).numpy()
                )
            # From now on this action at this location is known
            if action is not None:
                seen_action[origin, action] = True
            previous = current
        all_correct_zero_shot.append(outcomes)
    return all_correct_zero_shot
+ """ + # Get the number of actions in this model + n_actions = model.hyper["n_actions"] + model.hyper["has_static_action"] + # Store for each environment for each step whether is was predicted correctly by the model, and by a perfect node and + # perfect edge agent + all_correct_model, all_correct_node, all_correct_edge = [], [], [] + # Run through environments and check for correct or incorrect prediction + for env_i, env in enumerate(environments): + # Keep track for each location whether it has been visited + location_visited = np.full(env[2], False) + # And for each action in each location whether it has been taken + action_taken = np.full((env[2], n_actions), False) + # Make array to list whether the observation was predicted correctly or not for the model + correct_model = [] + # And the same for a node agent, that picks a random observation on first encounter of a node, and the correct + # one every next time + correct_node = [] + # And the same for an edge agent, that picks a random observation on first encounter of an edge, and the correct + # one every next time + correct_edge = [] + # Get the very first iteration + prev_iter = forward[0] + # Run through iterations of forward pass to check when an action is taken for the first time + for step in forward[1:]: + # Get the previous action and previous location + prev_a, prev_g = prev_iter.a[env_i], prev_iter.g[env_i]["id"] + # If the previous action was standing still: only count as valid transition standing still actions + # are included as zero-shot inference + if model.hyper["has_static_action"] and prev_a == 0 and not include_stay_still: + prev_a = None + # Mark the location of the previous iteration as visited + location_visited[prev_g] = True + # Update model prediction for this step + correct_model.append((torch.argmax(step.x_gen[2][env_i]) == torch.argmax(step.x[env_i])).numpy()) + # Update node agent prediction for this step: correct when this state was visited beofre, otherwise chance + 
def rate_map(forward, model, environments):
    """
    Build location x cell firing-rate matrices for every frequency module in every environment.

    Representations ``g_inf`` and ``p_inf`` are grouped by the location of the step they were
    recorded at, then averaged over the second half of the visits to each location. Locations
    that were never visited get a zero vector.
    Parameters
    ----------
    forward : list
        Steps of a forward pass; each step exposes ``g``, ``g_inf`` and ``p_inf``.
    model : TEM
        Model whose ``hyper["n_f"]``, ``hyper["n_p"]`` and ``hyper["n_g"]`` are read.
    environments : list
        One entry per environment; ``env[2]`` is read as the number of locations.
    Returns
    -------
    all_g : list
        Per environment, per frequency module: array [locations x cells] averaged from ``g_inf``.
    all_p : list
        Per environment, per frequency module: array [locations x cells] averaged from ``p_inf``.
    """

    def _late_visit_average(visits, n_cells):
        # Only the second half of the visits is used, when the model roughly knows the environment
        late = visits[len(visits) // 2 :]
        if len(late) > 0:
            return sum(late) / len(late)
        return np.zeros(n_cells)

    def _to_matrices(per_module_visits, cells_per_module):
        # Stack the per-location averages into one [locations x cells] matrix per module
        return [
            np.stack([_late_visit_average(visits, cells_per_module[f]) for visits in locations], axis=0)
            for f, locations in enumerate(per_module_visits)
        ]

    n_modules = model.hyper["n_f"]
    all_g, all_p = [], []
    for env_i, env in enumerate(environments):
        n_locations = env[2]
        # Visits are collected per frequency module, then per location
        p_visits = [[[] for _ in range(n_locations)] for _ in range(n_modules)]
        g_visits = [[[] for _ in range(n_locations)] for _ in range(n_modules)]
        for step in forward:
            location_id = step.g[env_i]["id"]
            for f in range(n_modules):
                g_visits[f][location_id].append(step.g_inf[f][env_i].detach().numpy())
                p_visits[f][location_id].append(step.p_inf[f][env_i].detach().numpy())
        p = _to_matrices(p_visits, model.hyper["n_p"])
        g = _to_matrices(g_visits, model.hyper["n_g"])
        all_g.append(g)
        all_p.append(p)
    return all_g, all_p
def forward(self, walk, prev_iter=None, prev_M=None):
    """
    Run the model over a walk, producing one Iteration object per step.

    Parameters
    -----------
    walk: iterable of [place, observation, action] entries, one per step
    prev_iter: previous iteration(s), passed to ``self.init_walks`` to obtain the starting state.
        When ``init_walks`` returns None, a fresh starting iteration is built from the first step.
    prev_M: memory passed to ``self.init_iteration`` when a fresh starting iteration is built.
    Returns
    -----------
    steps: list of Iteration objects, one per step of the walk (the starting iteration is not included).
    """
    # Starting state carried over from the previous walk, if any
    history = self.init_walks(prev_iter)
    for locations, observation, actions in walk:
        if history is None:
            # No previous iteration at all: start from an initial iteration without previous actions
            no_actions = [None] * len(actions)
            history = [self.init_iteration(locations, observation, no_actions, prev_M)]
        # Transition from the most recent iteration to this step
        last = history[-1]
        outputs = self.iteration(observation, locations, last.a, last.M, last.x_inf, last.g_inf)
        # outputs holds (L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf), in that order
        history.append(Iteration(locations, observation, actions, *outputs))
    # Drop the first entry: it is the carried-over or initial iteration, not a step of this walk
    return history[1:]
infer grounded location p_inf (hippocampus), abstract location g_inf (entorhinal). + # Also keep filtered sensory observation (x_inf), and retrieved grounded location p_inf_x + x_inf, g_inf, p_inf_x, p_inf = self.inference(x, locations, M_prev, x_prev, gt_inf) + # Run generative model: since generative model is only used for training purposes, it will generate from + # *inferred* variables instead of *generated* variables (as it would when used for generation) + x_gen, x_logits, p_gen = self.generative(M_prev, p_inf, g_inf, gt_gen) + # Update generative memory with generated and inferred grounded location. + M = [self.hebbian(M_prev[0], torch.cat(p_inf, dim=1), torch.cat(p_gen, dim=1))] + # If using memory for grounded location inference: append inference memory + if self.hyper[ + "use_p_inf" + ]: # Inference memory is identical to generative memory if using common memory, and updated separatedly if not + M.append( + M[0] + if self.hyper["common_memory"] + else self.hebbian( + M_prev[1], torch.cat(p_inf, dim=1), torch.cat(p_inf_x, dim=1), do_hierarchical_connections=False + ) + ) + # Calculate loss of this step + L = self.loss(gt_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x) + # Return all iteration values + return L, M, gt_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf + + def inference(self, x, locations, M_prev, x_prev, g_gen): + """ + Perform inference step of the TEM model. 
+ Parameters + ----------- + x: sensory observation, one-hot vector + locations: list of locations + M_prev: previous memory connectivity matrix + x_prev: previous sensory experience + g_gen: previous abstract location + Returns + ----------- + x_f: filtered sensory experience + g: inferred abstract (grid cell) locations + p_x: retrieved grounded (place cell) locations from memory by sensory experience + p: inferred grounded (place cell) locations + """ + # Compress sensory observation from one-hot to two-hot (or alternatively, whatever an MLP makes of it) + x_c = self.f_c(x) + # Temporally filter sensory observation by mixing it with previous experience + x_f = self.x_prev2x(x_prev, x_c) + # Prepare sensory experience for input to memory by normalisation and weighting + x_ = self.x2x_(x_f) + # Retrieve grounded location from memory by doing pattern completion on current sensory experience + p_x = ( + self.attractor(x_, M_prev[1], retrieve_it_mask=self.hyper["p_retrieve_mask_inf"]) + if self.hyper["use_p_inf"] + else None + ) + # Infer abstract location by combining previous abstract location and grounded location retrieved from + # memory by current sensory experience + g = self.inf_g(p_x, g_gen, x, locations) + # Prepare abstract location for input to memory by downsampling and weighting + g_ = self.g2g_(g) + # Infer grounded location from sensory experience and inferred abstract location + p = self.inf_p(x_, g_) + # Return variables in order that they were created + return x_f, g, p_x, p + + def generative(self, M_prev, p_inf, g_inf, g_gen): + """ + Perform generative step of the TEM model. 
+ Parameters + ----------- + M_prev: previous memory connectivity matrix + p_inf: inferred grounded (place cell) locations + g_inf: inferred abstract (grid cell) locations + g_gen: previous abstract locations + Returns + ----------- + x_p: generated observation from inferred grounded location + x_g: generated observation from grounded location retrieved from inferred abstract + location + x_gt: generated observation from grounded location retrieved from abstract location + by transitioning + x_p_logits: logits for generated observation from inferred grounded location + x_g_logits: logits for generated observation from grounded location retrieved from + inferred abstract location + x_gt_logits: logits for generated observation from grounded location retrieved from + abstract location by transitioning + p_g_inf: grounded location retrieved from inferred abstract location + """ + # Generate observation from inferred grounded location, using only the highest frequency. + # Also keep non-softmaxed logits which are used in the loss later + x_p, x_p_logits = self.gen_x(p_inf[0]) + # Retrieve grounded location from memory by pattern completion on inferred abstract + # location + p_g_inf = self.gen_p(g_inf, M_prev[0]) # was p_mem_gen + # And generate observation from the grounded location retrieved from inferred abstract + # location + x_g, x_g_logits = self.gen_x(p_g_inf[0]) + # Retreive grounded location from memory by pattern completion on abstract location by + # transitioning + p_g_gen = self.gen_p(g_gen, M_prev[0]) + # Generate observation from sampled grounded location + x_gt, x_gt_logits = self.gen_x(p_g_gen[0]) + # Return all generated observations and their corresponding logits + return (x_p, x_g, x_gt), (x_p_logits, x_g_logits, x_gt_logits), p_g_inf + + def loss(self, g_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x): + """ + Calculate losses of the current TEM iteration for each of the generated and inferred + variable. 
+ Parameters + ----------- + g_gen: previous abstract (grid cell) locations + p_gen: previous grounded (place cell) locations + x_logits: logits for generated observation from grounded location retrieved from abstract + location by transitioning + x: sensory observation, one-hot vector + g_inf: inferred abstract (grid cell) locations + p_inf: inferred grounded (place cell) locations + p_inf_x: retrieved grounded (place cell) locations from memory by sensory experience + Returns + ----------- + L: list of losses for this iteration + """ + # Calculate loss function, separately for each component because you might want to reweight + # contributions later + # L_p_gen is squared error loss between inferred grounded location and grounded location retrieved + # from inferred abstract location + L_p_g = torch.sum(torch.stack(utils.squared_error(p_inf, p_gen), dim=0), dim=0) + # L_p_inf is squared error loss between inferred grounded location and grounded location retrieved + # from sensory experience + L_p_x = ( + torch.sum(torch.stack(utils.squared_error(p_inf, p_inf_x), dim=0), dim=0) + if self.hyper["use_p_inf"] + else torch.zeros_like(L_p_g) + ) + # L_g is squared error loss between generated abstract location and inferred abstract location + L_g = torch.sum(torch.stack(utils.squared_error(g_inf, g_gen), dim=0), dim=0) + # L_x is a cross-entropy loss between sensory experience and different model predictions. 
First get + # true labels from sensory experience + labels = torch.argmax(x, 1) + # L_x_gen: losses generated by generative model from g_prev -> g -> p -> x + L_x_gen = utils.cross_entropy(x_logits[2], labels) + # L_x_g: Losses generated by generative model from g_inf -> p -> x + L_x_g = utils.cross_entropy(x_logits[1], labels) + # L_x_p: Losses generated by generative model from p_inf -> x + L_x_p = utils.cross_entropy(x_logits[0], labels) + # L_reg are regularisation losses, L_reg_g on L2 norm of g + L_reg_g = torch.sum(torch.stack([torch.sum(g**2, dim=1) for g in g_inf], dim=0), dim=0) + # And L_reg_p regularisation on L1 norm of p + L_reg_p = torch.sum(torch.stack([torch.sum(torch.abs(p), dim=1) for p in p_inf], dim=0), dim=0) + # Return total loss as list of losses, so you can possibly reweight them + L = [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p] + return L + + def init_trainable(self): + """ + Initialise all trainable parameters of the TEM model. This is done in a separate function so that + it can be called again after loading a model from file. + """ + # Scale factor in Laplacian transform for each frequency module. High frequency comes first, low + # frequency comes last. 
Learn inverse sigmoid instead of scale factor directly, so domain of alpha is -inf, inf + self.alpha = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.tensor(np.log(self.hyper["f_initial"][f] / (1 - self.hyper["f_initial"][f])), dtype=torch.float) + ) + for f in range(self.hyper["n_f"]) + ] + ) + # Entorhinal preference weights + self.w_x = torch.nn.Parameter(torch.tensor(1.0)) + # Entorhinal preference bias + self.b_x = torch.nn.Parameter(torch.zeros(self.hyper["n_x_c"])) + # Frequency module specific scaling of sensory experience before input to hippocampus + self.w_p = torch.nn.ParameterList([torch.nn.Parameter(torch.tensor(1.0)) for f in range(self.hyper["n_f"])]) + # Initial activity of abstract location cells when entering a new environment, like a prior on g. + # Initialise with truncated normal + self.g_init = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.tensor( + truncnorm.rvs(-2, 2, size=self.hyper["n_g"][f], loc=0, scale=self.hyper["g_init_std"]), + dtype=torch.float, + ) + ) + for f in range(self.hyper["n_f"]) + ] + ) + # Log of standard deviation of abstract location cells when entering a new environment; standard + # deviation of the prior on g. Initialise with truncated normal + self.logsig_g_init = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.tensor( + truncnorm.rvs(-2, 2, size=self.hyper["n_g"][f], loc=0, scale=self.hyper["g_init_std"]), + dtype=torch.float, + ) + ) + for f in range(self.hyper["n_f"]) + ] + ) + # MLP for transition weights (not in paper, but recommended by James so you can learn about + # similarities between actions). 
Size is given by grid connections + self.MLP_D_a = MLP( + [self.hyper["n_actions"] for _ in range(self.hyper["n_f"])], + [ + sum( + [ + self.hyper["n_g"][f_from] + for f_from in range(self.hyper["n_f"]) + if self.hyper["g_connections"][f_to][f_from] + ] + ) + * self.hyper["n_g"][f_to] + for f_to in range(self.hyper["n_f"]) + ], + activation=[torch.tanh, None], + hidden_dim=[self.hyper["d_hidden_dim"] for _ in range(self.hyper["n_f"])], + bias=[True, False], + ) + # Initialise the hidden to output weights as zero, so initially you simply keep the current abstract + # location to predict the next abstract location + self.MLP_D_a.set_weights(1, 0.0) + # Transition weights without specifying an action for use in generative model with shiny objects + self.D_no_a = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.zeros( + sum( + [ + self.hyper["n_g"][f_from] + for f_from in range(self.hyper["n_f"]) + if self.hyper["g_connections"][f_to][f_from] + ] + ) + * self.hyper["n_g"][f_to] + ) + ) + for f_to in range(self.hyper["n_f"]) + ] + ) + # MLP for standard deviation of transition sample + self.MLP_sigma_g_path = MLP( + self.hyper["n_g"], + self.hyper["n_g"], + activation=[torch.tanh, torch.exp], + hidden_dim=[2 * g for g in self.hyper["n_g"]], + ) + # MLP for standard devation of grounded location from retrieved memory sample + self.MLP_sigma_p = MLP(self.hyper["n_p"], self.hyper["n_p"], activation=[torch.tanh, torch.exp]) + # MLP to generate mean of abstract location from downsampled abstract location, obtained by summing + # grounded location over sensory preferences in inference model + self.MLP_mu_g_mem = MLP(self.hyper["n_g_subsampled"], self.hyper["n_g"], hidden_dim=[2 * g for g in self.hyper["n_g"]]) + # Initialise weights in last layer of MLP_mu_g_mem as truncated normal for each frequency module + self.MLP_mu_g_mem.set_weights( + -1, + [ + torch.tensor( + truncnorm.rvs( + -2, 2, size=list(self.MLP_mu_g_mem.w[f][-1].weight.shape), loc=0, 
scale=self.hyper["g_mem_std"] + ), + dtype=torch.float, + ) + for f in range(self.hyper["n_f"]) + ], + ) + # MLP to generate standard deviation of abstract location from two measures (generated observation + # error and inferred abstract location vector norm) of memory quality + self.MLP_sigma_g_mem = MLP( + [2 for _ in self.hyper["n_g_subsampled"]], + self.hyper["n_g"], + activation=[torch.tanh, torch.exp], + hidden_dim=[2 * g for g in self.hyper["n_g"]], + ) + # MLP to generate mean of abstract location directly from shiny object presence. Outputs to object + # vector cell modules if they're separated, else to all abstract location modules + self.MLP_mu_g_shiny = MLP( + [1 for _ in range(self.hyper["n_f_ovc"] if self.hyper["separate_ovc"] else self.hyper["n_f"])], + [n_g for n_g in self.hyper["n_g"][(self.hyper["n_f_g"] if self.hyper["separate_ovc"] else 0) :]], + hidden_dim=[2 * n_g for n_g in self.hyper["n_g"][(self.hyper["n_f_g"] if self.hyper["separate_ovc"] else 0) :]], + ) + # MLP to generate standard deviation of abstract location directly from shiny object presence. + # Outputs to object vector cell modules if they're separated, else to all abstract location modules + self.MLP_sigma_g_shiny = MLP( + [1 for _ in range(self.hyper["n_f_ovc"] if self.hyper["separate_ovc"] else self.hyper["n_f"])], + [n_g for n_g in self.hyper["n_g"][(self.hyper["n_f_g"] if self.hyper["separate_ovc"] else 0) :]], + hidden_dim=[2 * n_g for n_g in self.hyper["n_g"][(self.hyper["n_f_g"] if self.hyper["separate_ovc"] else 0) :]], + activation=[torch.tanh, torch.exp], + ) + # MLP for decompressing highest frequency sensory experience to sensory observation + self.MLP_c_star = MLP(self.hyper["n_x_f"][0], self.hyper["n_x"], hidden_dim=20 * self.hyper["n_x_c"]) + + def init_iteration(self, g, x, a, M): + """ + Initialise a new iteration of the TEM model. 
+ Parameters + ----------- + g: previous abstract (grid cell) locations + x: sensory observation, one-hot vector + a: previous action, one-hot vector + M: previous memory connectivity matrix + Returns + ----------- + iteration: Iteration object, which contains all variables of the model for this iteration. + """ + # On the very first iteration, update the batch size based on the data. This is useful when + # doing analysis on the network with different batch sizes compared to training + self.hyper["batch_size"] = x.shape[0] + # Initalise hebbian memory connectivity matrix [M_gen, M_inf] if it wasn't initialised yet + if M is None: + # Create new empty memory dict for generative network: zero connectivity matrix M_0, then empty + # list of the memory vectors a and b for each iteration for efficient hebbian memory computation + M = [torch.zeros((self.hyper["batch_size"], sum(self.hyper["n_p"]), sum(self.hyper["n_p"])), dtype=torch.float)] + # Append inference memory only if memory is used in grounded location inference + if self.hyper["use_p_inf"]: + # If inference and generative network share common memory: reuse same connectivity, and + # same memory vectors. 
Else, create a new empty memory list for inference network + M.append( + M[0] + if self.hyper["common_memory"] + else torch.zeros( + (self.hyper["batch_size"], sum(self.hyper["n_p"]), sum(self.hyper["n_p"])), dtype=torch.float + ) + ) + # Initialise previous abstract location by stacking abstract location prior + g_inf = [torch.stack([self.g_init[f] for _ in range(self.hyper["batch_size"])]) for f in range(self.hyper["n_f"])] + # Initialise previous sensory experience with zeros, as there is no data yet for + # temporal smoothing + x_inf = [torch.zeros((self.hyper["batch_size"], self.hyper["n_x_f"][f])) for f in range(self.hyper["n_f"])] + # And construct new iteration for that g, x, a, and M + return Iteration(g=g, x=x, a=a, M=M, x_inf=x_inf, g_inf=g_inf) + + def init_walks(self, prev_iter): + """ + Initialise a new walk of the TEM model. + Parameters + ----------- + prev_iter: list of Iteration objects, which contain all variables of the model for each step of the + previous walk. If None, all variables are initialised as zero. + Returns + ----------- + prev_iter: list of Iteration objects, which contain all variables of the model for each step of the + previous walk. If None, all variables are initialised as zero. + """ + # Only reset parameters for previous iteration if a previous iteration was actually provided - if + # it wasn't, all parameters will be reset when creating a fresh Iteration object in init_iteration + if prev_iter is not None: + # The supplied previous iteration might have new walks starting, with empty actions. 
For these + # walks some parameters need to be reset + for a_i, a in enumerate(prev_iter[0].a): + # A new walk is indicated by having a None action in the previous iteration + if a is None: + # Reset the initial connectivity matrix for this walk + for M in prev_iter[0].M: + M[a_i, :, :] = 0 + # Reset the abstract location for this walk + for f, g_inf in enumerate(prev_iter[0].g_inf): + g_inf[a_i, :] = self.g_init[f] + # Reset the sensory experience for this walk + for f, x_inf in enumerate(prev_iter[0].x_inf): + x_inf[a_i, :] = torch.zeros(self.hyper["n_x_f"][f]) + # Return the iteration with reset parameters (or simply the empty array if prev_iter was empty) + return prev_iter + + def gen_g(self, a_prev, g_prev, locations): + """ + Perform transition step of the TEM model. This consists of a transition from previous abstract location + to new abstract location using weights specific to + action taken for each frequency module. + Parameters + ----------- + a_prev: previous action, one-hot vector + g_prev: previous abstract (grid cell) locations + locations: list of locations + Returns + ----------- + g_gen: generated abstract (grid cell) locations + g: abstract (grid cell) locations + """ + # Transition from previous abstract location to new abstract location using weights specific to action + # taken for each frequency module + mu_g = self.f_mu_g_path(a_prev, g_prev) + sigma_g = self.f_sigma_g_path(a_prev, g_prev) + # Either sample new abstract location g or simply take the mean of distribution in noiseless case. 
+ g = [ + mu_g[f] + sigma_g[f] * np.random.randn() if self.hyper["do_sample"] else mu_g[f] for f in range(self.hyper["n_f"]) + ] + # But for environments with shiny objects, the transition to the new abstract location shouldn't have + # access to the action direction in the generative model + shiny_envs = [location["shiny"] is not None for location in locations] + # If there are any shiny environments, the abstract locations for the generative model will need to be + # re-calculated without providing actions for those + g_gen = self.f_mu_g_path(a_prev, g_prev, no_direc=shiny_envs) if any(shiny_envs) else g + # Return generated abstract location after transition + return g_gen, (g, sigma_g) + + def gen_p(self, g, M_prev): + """ + Generate grounded location from abstract location by pattern completion on abstract location. This is + used in the generative model to generate observations from grounded location. + Parameters + ----------- + g: abstract (grid cell) locations + M_prev: previous memory connectivity matrix + Returns + ----------- + p: generated grounded (place cell) locations + """ + # We want to use g as an index for memory retrieval, but it doesn't have the right dimensions (these + # are grid cells, we need place cells). We need g_ instead + g_ = self.g2g_(g) + # Retreive memory: do pattern completion on abstract location to get grounded location + mu_p = self.attractor(g_, M_prev, retrieve_it_mask=self.hyper["p_retrieve_mask_gen"]) + sigma_p = self.f_sigma_p(mu_p) + # Either sample new grounded location p or simply take the mean of distribution in noiseless case + p = [ + mu_p[f] + sigma_p[f] * np.random.randn() if self.hyper["do_sample"] else mu_p[f] for f in range(self.hyper["n_f"]) + ] + # Return pattern-completed grounded location p after memory retrieval + return p + + def gen_x(self, p): + """ + Generate observation from grounded location by MLP. This is used in the generative model to generate + observations from grounded location. 
+ Parameters + ----------- + p: grounded (place cell) locations + Returns + ----------- + x: generated observation + logits: logits for generated observation + """ + # Get categorical distribution over observations from grounded location + # If you actually want to sample observation, you need a reparaterisation trick for categorical distributions + # Sampling would be the correct way to do this, since observations are discrete, and it's also what the + # TEM paper says + # However, it looks like you could also get away with using categorical distribution directly as an + # approximation of the one-hot observations + if self.hyper["do_sample"]: + x, logits = self.f_x(p) # This is a placeholder! Should be done using reparameterisation trick (like + # https://blog.evjang.com/2016/11/tutorial-categorical-variational.html) + else: + x, logits = self.f_x(p) + # Return one-hot (or almost one-hot...) observation obtained from grounded location, and also + # the non-softmaxed logits + return x, logits + + def inf_g(self, p_x, g_gen, x, locations): + """ + Infer abstract location from the combination of [grounded location retrieved from memory by sensory + experience] and [previous abstract location and action (path integration)] + Parameters + ----------- + p_x: retrieved grounded (place cell) locations from memory by sensory experience + g_gen: previous abstract (grid cell) locations + x: sensory observation, one-hot vector + locations: list of locations + Returns + ----------- + mu_g: inferred abstract (grid cell) locations + sigma_g: standard deviation of inferred abstract (grid cell) locations + """ + # Infer abstract location from the combination of [grounded location retrieved from memory by sensory + # experience] ... 
+ if self.hyper["use_p_inf"]: + # Not in paper, but makes sense from symmetry with f_x: first get g from p by "summing over sensory + # preferences" g = p * W_repeat^T + g_downsampled = [torch.matmul(p_x[f], torch.t(self.hyper["W_repeat"][f])) for f in range(self.hyper["n_f"])] + # Then use abstract location after summing over sensory preferences as input to MLP to obtain the + # inferred abstract location from memory + mu_g_mem = self.f_mu_g_mem(g_downsampled) + # Not in paper, but this greatly improves zero-shot inference: provide the uncertainty function of the + # inferred abstract location with measures of memory quality + with torch.no_grad(): + # For the first measure, use the grounded location inferred from memory to generate an observation + x_hat, x_hat_logits = self.gen_x(p_x[0]) + # Then calculate the error between the generated observation and the actual observation: if the memory + # is working well, this error should be small + err = utils.squared_error(x, x_hat) + # The second measure is the vector norm of the inferred abstract location; good memories should have similar + # vector norms. Concatenate the two measures as input for the abstract location uncertainty function + sigma_g_input = [ + torch.cat((torch.sum(g**2, dim=1, keepdim=True), torch.unsqueeze(err, dim=1)), dim=1) for g in mu_g_mem + ] + # Not in paper, but recommended by James for stability: get final mean of inferred abstract location by + # clamping activations between -1 and 1 + mu_g_mem = self.f_g_clamp(mu_g_mem) + # And get standard deviation/uncertainty of inferred abstract location by providing uncertainty function with + # memory quality measures + sigma_g_mem = self.f_sigma_g_mem(sigma_g_input) + # ... 
and [previous abstract location and action (path integration)] + mu_g_path = g_gen[0] + sigma_g_path = g_gen[1] + # Infer abstract location by combining previous abstract location and grounded location retrieved from memory by + # current sensory experience + mu_g, sigma_g = [], [] + for f in range(self.hyper["n_f"]): + if self.hyper["use_p_inf"]: + # Then get full gaussian distribution of inferred abstract location by calculating precision weighted mean + mu, sigma = utils.inv_var_weight([mu_g_path[f], mu_g_mem[f]], [sigma_g_path[f], sigma_g_mem[f]]) + else: + # Or simply completely ignore the inference memory here, to test if things are working + mu, sigma = mu_g_path[f], sigma_g_path[f] + # Append mu and sigma to list for all frequency modules + mu_g.append(mu) + sigma_g.append(sigma) + # Finally (though not in paper), also add object vector cell information to inferred abstract location for + # environments with shiny objects + shiny_envs = [location["shiny"] is not None for location in locations] + if any(shiny_envs): + # Find for which environments the current location has a shiny object + shiny_locations = torch.unsqueeze( + torch.stack( + [ + torch.tensor(location["shiny"], dtype=torch.float) + for location in locations + if location["shiny"] is not None + ] + ), + dim=-1, + ) + # Get abstract location for environments with shiny objects and feed to each of the object vector cell modules + mu_g_shiny = self.f_mu_g_shiny( + [shiny_locations for _ in range(self.hyper["n_f_g"] if self.hyper["separate_ovc"] else self.hyper["n_f"])] + ) + sigma_g_shiny = self.f_sigma_g_shiny( + [shiny_locations for _ in range(self.hyper["n_f_g"] if self.hyper["separate_ovc"] else self.hyper["n_f"])] + ) + # Update only object vector modules with shiny-inferred abstract location: start from offset if object vector + # modules are separate + module_start = self.hyper["n_f_g"] if self.hyper["separate_ovc"] else 0 + # Inverse variance weighting is associative, so I can just do 
additional inverse variance weighting to the + # previously obtained mu and sigma - but only for object vector cell modules! + for f in range(module_start, self.hyper["n_f"]): + # Add inferred abstract location from shiny objects to previously obtained position, only for environments + # with shiny objects + mu, sigma = utils.inv_var_weight( + [mu_g[f][shiny_envs, :], mu_g_shiny[f - module_start]], + [sigma_g[f][shiny_envs, :], sigma_g_shiny[f - module_start]], + ) + # In order to update only the environments with shiny objects, without in-place value assignment, + # construct a mask of shiny environments + mask = torch.zeros_like(mu_g[f], dtype=torch.bool) + mask[shiny_envs, :] = True + # Use mask to update the shiny environment entries in inferred abstract locations + mu_g[f] = mu_g[f].masked_scatter(mask, mu) + sigma_g[f] = sigma_g[f].masked_scatter(mask, sigma) + # Either sample inferred abstract location from combined (precision weighted) distribution or just take + # mean + g = [ + mu_g[f] + sigma_g[f] * np.random.randn() if self.hyper["do_sample"] else mu_g[f] for f in range(self.hyper["n_f"]) + ] + # Return abstract location inferred from grounded location from memory and previous abstract location + return g + + def inf_p(self, x_, g_): + """ + Infer grounded location from sensory experience and inferred abstract location for each module + Parameters + ----------- + x_: sensory input to memory + g_: abstract (grid cell) locations + Returns + ----------- + p: inferred grounded (place cell) locations + """ + # Infer grounded location from sensory experience and inferred abstract location for each module + p = [] + # Use the same transformation for each frequency module: leaky relu for sparsity + for f in range(self.hyper["n_f"]): + mu_p = self.f_p(g_[f] * x_[f]) # This is element-wise multiplication + sigma_p = 0 # Unclear from paper (typo?). 
Some undefined function f that takes two arguments: f(f_n(x),g) + # Either sample inferred grounded location or just take mean + if self.hyper["do_sample"]: + p.append(mu_p + sigma_p * np.random.randn()) + else: + p.append(mu_p) + # Return new memory constructed from sensory experience and inferred abstract location + return p + + def x_prev2x(self, x_prev, x_c): + """ + Perform temporal filtering on sensory experience for each frequency module + Parameters + ----------- + x_prev: previous sensory experience + x_c: current sensory experience + Returns + ----------- + x: filtered sensory experience + """ + # Calculate factor for filtering from sigmoid of learned parameter + alpha = [torch.nn.Sigmoid()(self.alpha[f]) for f in range(self.hyper["n_f"])] + # Do exponential temporal filtering for each frequency modulemod + x = [(1 - alpha[f]) * x_prev[f] + alpha[f] * x_c for f in range(self.hyper["n_f"])] + return x + + def x2x_(self, x): + """ + Prepare sensory input for input to memory by weighting and normalisation for each frequency module + Parameters + ----------- + x: sensory observation, one-hot vector + Returns + ----------- + x_: weighted and normalised sensory input to memory + """ + # Prepare sensory input for input to memory by weighting and normalisation for each frequency module + # Get normalised sensory input for each frequency module + normalised = self.f_n(x) + # Then reshape and reweight (use sigmoid to keep weight between 0 and 1) each frequency module separately: + # matrix multiplication by W_tile prepares x for outer product with g by element-wise multiplication + x_ = [ + torch.nn.Sigmoid()(self.w_p[f]) * torch.matmul(normalised[f], self.hyper["W_tile"][f]) + for f in range(self.hyper["n_f"]) + ] + return x_ + + def g2g_(self, g): + """ + Prepare abstract location for input to memory by reshaping and down-sampling for each frequency module + Parameters + ----------- + g: abstract (grid cell) locations + Returns + ----------- + g_: downsampled 
abstract (grid cell) locations + """ + # Prepares abstract location for input to memory by reshaping and down-sampling for each frequency module + # Get downsampled abstract location for each frequency module + downsampled = self.f_g(g) + # Then reshape and reweight each frequency module separately + g_ = [torch.matmul(downsampled[f], self.hyper["W_repeat"][f]) for f in range(self.hyper["n_f"])] + return g_ + + def f_mu_g_path(self, a_prev, g_prev, no_direc=None): + """ + Perform transition step of the TEM model. This consists of a transition from previous abstract location + to new abstract location using weights specific to + action taken for each frequency module. + Parameters + ----------- + a_prev: previous action, one-hot vector + g_prev: previous abstract (grid cell) locations + no_direc: list of booleans indicating for which walks the transition direction should be omitted + (e.g. no shiny objects, or in inference model: set to all false) + Returns + ----------- + g_step: new abstract (grid cell) locations + """ + # If there are no environments where the transition direction needs to be omitted (e.g. no shiny objects, + # or in inference model: set to all false + no_direc = [False for _ in a_prev] if no_direc is None else no_direc + # Remove all Nones from a_prev: these are walks where there was no previous action, so no step needs to + # be calculated for those + a_prev_step = [a if a is not None else 0 for a in a_prev] + # And also keep track of which walks these valid step actions are for + a_do_step = [a is not None for a in a_prev] + # Transform list of actions into batch of one-hot row vectors. + if self.hyper["has_static_action"]: + # If this world has static actions: whenever action 0 (standing still) appears, the action vector + # should be all zeros. 
All other actions should have a 1 in the label-1 entry + a = torch.zeros((len(a_prev_step), self.hyper["n_actions"])).scatter_( + 1, + torch.clamp(torch.tensor(a_prev_step).unsqueeze(1) - 1, min=0), + 1.0 * (torch.tensor(a_prev_step).unsqueeze(1) > 0), + ) + else: + # Without static actions: each action label should become a one-hot vector for that label + a = torch.zeros((len(a_prev_step), self.hyper["n_actions"])).scatter_( + 1, torch.tensor(a_prev_step).unsqueeze(1), 1.0 + ) + # Get vector of transition weights by feeding actions into MLP + D_a = self.MLP_D_a([a for _ in range(self.hyper["n_f"])]) + # Replace transition weights by non-directional transition weights in environments where transition direction + # needs to be omitted (can set only if any no_direc) + for f in range(self.hyper["n_f"]): + D_a[f][no_direc, :] = self.D_no_a[f] + # Reshape transition weight vector into transition matrix. The number of rows in the transition matrix is given + # by the incoming abstract location connections for each frequency module + D_a = [ + torch.reshape( + D_a[f_to], + ( + -1, + sum( + [ + self.hyper["n_g"][f_from] + for f_from in range(self.hyper["n_f"]) + if self.hyper["g_connections"][f_to][f_from] + ] + ), + self.hyper["n_g"][f_to], + ), + ) + for f_to in range(self.hyper["n_f"]) + ] + # Select the frequency modules of the previous abstract location that are connected to each frequency module, to + g_in = [ + torch.unsqueeze( + torch.cat( + [g_prev[f_from] for f_from in range(self.hyper["n_f"]) if self.hyper["g_connections"][f_to][f_from]], dim=1 + ), + 1, + ) + for f_to in range(self.hyper["n_f"]) + ] + # Reshape transition weight vector into transition matrix. 
The number of rows in the transition matrix is given by the + # incoming abstract location connections for each frequency module + delta = [torch.squeeze(torch.matmul(g, T)) for g, T in zip(g_in, D_a)] + # Not in the paper, but recommended by James for stability: use inferred code as *difference* in abstract location. + # Calculate new abstract location from previous abstract location and difference + g_step = [g + d if g.dim() > 1 else torch.unsqueeze(g + d, 0) for g, d in zip(g_prev, delta)] + # Not in paper, but recommended by James for stability: clamp activations between -1 and 1 + g_step = self.f_g_clamp(g_step) + # Build new abstract location from result of transition if there was one, or from prior on abstract location + # if there wasn't + return [ + torch.stack([g_step[f][batch_i, :] if do_step else self.g_init[f] for batch_i, do_step in enumerate(a_do_step)]) + for f in range(self.hyper["n_f"]) + ] + + def f_sigma_g_path(self, a_prev, g_prev): + """ + Use multi layer perceptron to generate standard deviation from all previous abstract locations, including those + that were just initialised and not real previous locations. 
+ Parameters + ----------- + a_prev: previous action, one-hot vector + g_prev: previous abstract (grid cell) locations + Returns + ----------- + sigma_g: standard deviation of new abstract (grid cell) locations + """ + # Keep track of which walks these valid step actions are for + a_do_step = [a is not None for a in a_prev] + from_g = self.MLP_sigma_g_path(g_prev) + # And take exponent to get prior sigma for the walks that didn't have a previous location + from_prior = [torch.exp(logsig) for logsig in self.logsig_g_init] + # Now select the standard deviation generated from the previous abstract location if there was one, and the prior + # standard deviation on abstract location otherwise + return [ + torch.stack([from_g[f][batch_i, :] if do_step else from_prior[f] for batch_i, do_step in enumerate(a_do_step)]) + for f in range(self.hyper["n_f"]) + ] + + def f_mu_g_mem(self, g_downsampled): + """ + Use multi layer perceptron to generate mean of abstract location from down-sampled abstract location, obtained by + summing over sensory dimension of grounded location. + Parameters + ----------- + g_downsampled: downsampled abstract (grid cell) locations + Returns + ----------- + mu_g_mem: inferred abstract (grid cell) locations + """ + return self.MLP_mu_g_mem(g_downsampled) + + def f_sigma_g_mem(self, g_downsampled): + """ + Use multi layer perceptron to generate standard deviation of abstract location from down-sampled abstract location, + obtained by summing over sensory dimension of grounded location. 
+ Parameters + ----------- + g_downsampled: downsampled abstract (grid cell) locations + Returns + ----------- + sigma_g_mem: standard deviation of inferred abstract (grid cell) locations + """ + sigma = self.MLP_sigma_g_mem(g_downsampled) + # Not in paper, but also offset this sigma over training, so you can reduce influence of inferred p early on + return [sigma[f] + self.hyper["p2g_scale_offset"] * self.hyper["p2g_sig_val"] for f in range(self.hyper["n_f"])] + + def f_mu_g_shiny(self, shiny): + """ + Use multi layer perceptron to generate mean of abstract location from boolean location shiny-ness. Outputs to + object vector cell modules if they're separated, + else to all abstract location modules + Parameters + ----------- + shiny: boolean location shiny-ness + Returns + ----------- + mu_g: inferred abstract (grid cell) locations + """ + mu_g = self.MLP_mu_g_shiny(shiny) + # Take absolute because James wants object vector cells to be positive + mu_g = [torch.abs(mu) for mu in mu_g] + # Then apply clamp and leaky relu to get object vector module activations, like it's done for ground location + # activations + g = self.f_p(mu_g) + return g + + def f_sigma_g_shiny(self, shiny): + """ + Use multi layer perceptron to generate standard deviation of abstract location from boolean location shiny-ness. + Outputs to object vector cell modules if they're separated, + else to all abstract location modules + Parameters + ----------- + shiny: boolean location shiny-ness + Returns + ----------- + sigma_g: standard deviation of inferred abstract (grid cell) locations + """ + return self.MLP_sigma_g_shiny(shiny) + + def f_sigma_p(self, p): + """ + Use multi layer perceptron to generate standard deviation of grounded location retrieval. 
+ Parameters + ----------- + p: grounded (place cell) locations + Returns + ----------- + sigma_p: standard deviation of grounded (place cell) locations + """ + return self.MLP_sigma_p(p) + + def f_x(self, p): + """ + Use multi layer perceptron to generate observation from grounded location. Do this by first calculating + categorical probability distribution over observations for a given + ground location, then sampling from that distribution. + Parameters + ----------- + p: grounded (place cell) locations + Returns + ----------- + x: generated observation + logits: logits for generated observation + """ + # Calculate categorical probability distribution over observations for a given ground location + # p is the flattened (by concatenating rows - like reading sentences) outer product of g and x (p = g^T * x). + # Therefore to get the sensory experience x for a grounded location p, sum over all abstract locations + # g for each component of x. That's what the paper means when it says "sum over entorhinal preferences". 
+ # It can be done with the transpose of W_tile + x = self.w_x * torch.matmul(p, torch.t(self.hyper["W_tile"][0])) + self.b_x + # Then we need to decompress the temporally filtered sensory experience into a single current experience prediction + logits = self.f_c_star(x) + # We'll keep both the logits (domain -inf, inf) and probabilities (domain 0, 1) because both are needed later on + probability = utils.softmax(logits) + return probability, logits + + def f_c_star(self, compressed): + """ + Use multi layer perceptron to decompress sensory experience at highest frequency + Parameters + ----------- + compressed: compressed sensory experience + Returns + ----------- + decompressed: decompressed sensory experience + """ + return self.MLP_c_star(compressed) + + def f_c(self, decompressed): + """ + Use multi layer perceptron to compress sensory experience at highest frequency + Parameters + ----------- + decompressed: decompressed sensory experience + Returns + ----------- + compressed: compressed sensory experience + """ + # Compress sensory observation from one-hot provided by world to two-hot for ease of computation + return torch.stack([self.hyper["two_hot_table"][i] for i in torch.argmax(decompressed, dim=1)], dim=0) + + def f_n(self, x): + """ + Normalise sensory observation for each frequency module. + Parameters + ----------- + x: sensory observation, one-hot vector + Returns + ----------- + normalised: normalised sensory observation + """ + normalised = [utils.normalise(utils.relu(x[f] - torch.mean(x[f]))) for f in range(self.hyper["n_f"])] + return normalised + + def f_g(self, g): + """ + Downsample abstract location for each frequency module. 
+ Parameters + ----------- + g: abstract (grid cell) locations + Returns + ----------- + downsampled: downsampled abstract (grid cell) locations + """ + downsampled = [torch.matmul(g[f], self.hyper["g_downsample"][f]) for f in range(self.hyper["n_f"])] + return downsampled + + def f_g_clamp(self, g): + """ + Calculate activation for abstract location, thresholding between -1 and 1. + Parameters + ----------- + g: abstract (grid cell) locations + Returns + ----------- + activation: activation for abstract (grid cell) locations + """ + activation = [torch.clamp(g_f, min=-1, max=1) for g_f in g] + return activation + + def f_p(self, p): + """ + Calculate activation for inferred grounded location, using a leaky relu for sparsity. Either apply to full + multi-frequency grounded location or single frequency module. + Parameters + ----------- + p: grounded (place cell) locations + Returns + ----------- + activation: activation for grounded (place cell) locations + """ + activation = ( + [utils.leaky_relu(torch.clamp(p_f, min=-1, max=1)) for p_f in p] + if type(p) is list + else utils.leaky_relu(torch.clamp(p, min=-1, max=1)) + ) + return activation + + def attractor(self, p_query, M, retrieve_it_mask=None): + """ + Retrieve grounded location from attractor network memory with weights M by pattern-completing query. For example, + initial attractor input can come from abstract location + (g_) or sensory experience (x_). + Parameters + ----------- + p_query: initial attractor input + M: memory connectivity matrix + retrieve_it_mask: mask for hierarchical retrieval + Returns + ----------- + p: retrieved grounded (place cell) locations + """ + # Start by flattening query grounded locations across frequency modules + h_t = torch.cat(p_query, dim=1) + # Apply activation function to initial memory index + h_t = self.f_p(h_t) + # Hierarchical retrieval (not in paper) is implemented by early stopping retrieval for low frequencies, + # using a mask. 
If not specified: initialise mask as all 1s + retrieve_it_mask = ( + [torch.ones(sum(self.hyper["n_p"])) for _ in range(self.hyper["n_p"])] + if retrieve_it_mask is None + else retrieve_it_mask + ) + # Iterate attractor dynamics to do pattern completion + for tau in range(self.hyper["i_attractor"]): + # Apply one iteration of attractor dynamics, but only where there is a 1 in the mask. NB retrieve_it_mask + # entries have only one row, but are broadcasted to batch_size + h_t = (1 - retrieve_it_mask[tau]) * h_t + retrieve_it_mask[tau] * ( + self.f_p(self.hyper["kappa"] * h_t + torch.squeeze(torch.matmul(torch.unsqueeze(h_t, 1), M))) + ) + # Make helper list of cumulative neurons per frequency module for grounded locations + n_p = np.cumsum(np.concatenate(([0], self.hyper["n_p"]))) + # Now re-cast the grounded location into different frequency modules, since memory retrieval turned it + # into one long vector + p = [h_t[:, n_p[f] : n_p[f + 1]] for f in range(self.hyper["n_f"])] + return p + + def hebbian(self, M_prev, p_inferred, p_generated, do_hierarchical_connections=True): + """ + Update attractor network memory by Hebbian learning of pattern. For example, initial attractor input can + come from abstract location (g_) or sensory experience (x_). + Parameters + ----------- + M_prev: previous memory connectivity matrix + p_inferred: inferred grounded (place cell) locations + p_generated: generated grounded (place cell) locations + do_hierarchical_connections: whether to use hierarchical connections + Returns + ----------- + M: updated memory connectivity matrix + """ + # Create new ground memory for attractor network by setting weights to outer product of learned vectors + # p_inferred corresponds to p in the paper, and p_generated corresponds to p^. + # The order of p + p^ and p - p^ is reversed since these are row vectors, instead of column vectors in the paper. 
+ M_new = torch.squeeze( + torch.matmul(torch.unsqueeze(p_inferred + p_generated, 2), torch.unsqueeze(p_inferred - p_generated, 1)) + ) + # Multiply by connection vector, e.g. only keeping weights from low to high frequencies for hierarchical retrieval + if do_hierarchical_connections: + M_new = M_new * self.hyper["p_update_mask"] + # Store grounded location in attractor network memory with weights M by Hebbian learning of pattern + M = torch.clamp(self.hyper["lambda"] * M_prev + self.hyper["eta"] * M_new, min=-1, max=1) + return M + + +class MLP(torch.nn.Module): + """ + Class for multi layer perceptron with multiple modules. Each module has two layers: input->hidden and hidden->output. + """ + + def __init__(self, in_dim, out_dim, activation=(torch.nn.functional.elu, None), hidden_dim=None, bias=(True, True)): + """ + Initialise multi layer perceptron with multiple modules. + """ + # First call super class init function to set up torch.nn.Module style model and inherit it's functionality + super(MLP, self).__init__() + # Check if this network consists of module: are input and output dimensions lists? 
If not, make them + # (but remember it wasn't) + if type(in_dim) is list: + self.is_list = True + else: + in_dim = [in_dim] + out_dim = [out_dim] + self.is_list = False + # Find number of modules + self.N = len(in_dim) + # Create weights (input->hidden, hidden->output) for each module + self.w = torch.nn.ModuleList([]) + for n in range(self.N): + # If number of hidden dimensions is not specified: mean of input and output + if hidden_dim is None: + hidden = int(np.mean([in_dim[n], out_dim[n]])) + else: + hidden = hidden_dim[n] if self.is_list else hidden_dim + # Each module has two sets of weights: input->hidden and hidden->output + self.w.append( + torch.nn.ModuleList( + [torch.nn.Linear(in_dim[n], hidden, bias=bias[0]), torch.nn.Linear(hidden, out_dim[n], bias=bias[1])] + ) + ) + # Copy activation function for hidden layer and output layer + self.activation = activation + # Initialise all weights + with torch.no_grad(): + for from_layer in range(2): + for n in range(self.N): + # Set weights to xavier initalisation + torch.nn.init.xavier_normal_(self.w[n][from_layer].weight) + # Set biases to 0 + if bias[from_layer]: + self.w[n][from_layer].bias.fill_(0.0) + + def set_weights(self, from_layer, value): + """ + Set weights for each module from a given layer to a given value. 
+ Parameters + ----------- + from_layer: layer from which to set weights + value: value to set weights to + """ + # If single value is provided: copy it for each module + if type(value) is not list: + input_value = [value for n in range(self.N)] + else: + input_value = value + # Run through all modules and set weights starting from requested layer to the specified value + with torch.no_grad(): + # MLP is setup as follows: w[module][layer] is Linear object, w[module][layer].weight is Parameter object for + # linear weights, w[module][layer].weight.data is tensor of weight values + for n in range(self.N): + # If a tensor is provided: copy the tensor to the weights + if type(input_value[n]) is torch.Tensor: + self.w[n][from_layer].weight.copy_(input_value[n]) + # If only a single value is provided: set that value everywhere + else: + self.w[n][from_layer].weight.fill_(input_value[n]) + + def forward(self, data): + """ + Run input data through MLP network. + Parameters + ----------- + data: input data + Returns + ----------- + output: output data + """ + # Make input data into list, if this network doesn't consist of modules + if self.is_list: + input_data = data + else: + input_data = [data] + # Run input through network for each module + output = [] + for n in range(self.N): + # Pass through first weights from input to hidden layer + module_output = self.w[n][0](input_data[n]) + # Apply hidden layer activation + if self.activation[0] is not None: + module_output = self.activation[0](module_output) + # Pass through second weights from hidden to output layer + module_output = self.w[n][1](module_output) + # Apply output layer activation + if self.activation[1] is not None: + module_output = self.activation[1](module_output) + # Transpose output again to go back to column vectors instead of row vectors + output.append(module_output) + # If this network doesn't consist of modules: select output from first module to return + if not self.is_list: + output = output[0] + # 
And return output + return output + + +class LSTM(torch.nn.Module): + """ + Class for LSTM with multiple layers. + """ + + def __init__(self, in_dim, hidden_dim, out_dim, n_layers=1, n_a=4): + """ + Initialise LSTM with multiple layers. + """ + # First call super class init function to set up torch.nn.Module style model and inherit it's functionality + super(LSTM, self).__init__() + # LSTM layer + self.lstm = torch.nn.LSTM(in_dim, hidden_dim, n_layers, batch_first=True) + # Hidden to output + self.lin = torch.nn.Linear(hidden_dim, out_dim) + # Copy number of actions, will be needed for input data vector + self.n_a = n_a + + def forward(self, data, prev_hidden=None): + """ + Run input data through LSTM network. + Parameters + ----------- + data: input data + prev_hidden: previous hidden state + Returns + ----------- + out: output data + lstm_hidden: hidden state + """ + # If previous hidden and cell state are not provided: initialise them randomly + if prev_hidden is None: + hidden_state = torch.randn(self.lstm.num_layers, data.shape[0], self.lstm.hidden_size) + cell_state = torch.randn(self.lstm.num_layers, data.shape[0], self.lstm.hidden_size) + prev_hidden = (hidden_state, cell_state) + # Run input through lstm + lstm_out, lstm_hidden = self.lstm(data, prev_hidden) + # Apply linear network to lstm output to get output: prediction at each timestep + lin_out = self.lin(lstm_out) + # And since we want a one-hot prediciton: do softmax on top + out = utils.softmax(lin_out) + # Return output and hidden state + return out, lstm_hidden + + def prepare_data(self, data_in): + """ + Prepare input data for LSTM network. 
+ Parameters + ----------- + data_in: input data + Returns + ----------- + data: prepared input data + """ + # Transform list of actions of each step into batch of one-hot row vectors + actions = [ + torch.zeros((len(step[2]), self.n_a)).scatter_(1, torch.tensor(step[2]).unsqueeze(1), 1.0) for step in data_in + ] + # Concatenate observation and action together along column direction in each step + vectors = [torch.cat((step[1], action), dim=1) for step, action in zip(data_in, actions)] + # Then stack all these together along the second dimension, which is sequence length + data = torch.stack(vectors, dim=1) + # Return data in [batch_size, seq_len, input_dim] dimension as expected by lstm + return data + + +class Iteration: + """ + Class for storing all data from one iteration of the model. + """ + + def __init__( + self, + g=None, + x=None, + a=None, + L=None, + M=None, + g_gen=None, + p_gen=None, + x_gen=None, + x_logits=None, + x_inf=None, + g_inf=None, + p_inf=None, + ): + """ + Initialise iteration with all data from one iteration of the model. + """ + # Copy all inputs + self.g = g + self.x = x + self.a = a + self.L = L + self.M = M + self.g_gen = g_gen + self.p_gen = p_gen + self.x_gen = x_gen + self.x_logits = x_logits + self.x_inf = x_inf + self.g_inf = g_inf + self.p_inf = p_inf + + def correct(self): + """ + Calculate accuracy of model prediction for this iteration. + Returns + ----------- + accuracy: accuracy of model prediction for this iteration + """ + # Detach observation and all predictions + observation = self.x.detach().numpy() + predictions = [tensor.detach().numpy() for tensor in self.x_gen] + # Did the model predict the right observation in this iteration? + accuracy = [np.argmax(prediction, axis=-1) == np.argmax(observation, axis=-1) for prediction in predictions] + return accuracy + + def detach(self): + """ + Detach all tensors contained in this iteration. 
+ Returns + ----------- + self: iteration with all tensors detached + """ + # Detach all tensors contained in this iteration + self.L = [tensor.detach() for tensor in self.L] + self.M = [tensor.detach() for tensor in self.M] + self.g_gen = [tensor.detach() for tensor in self.g_gen] + self.p_gen = [tensor.detach() for tensor in self.p_gen] + self.x_gen = [tensor.detach() for tensor in self.x_gen] + self.x_inf = [tensor.detach() for tensor in self.x_inf] + self.g_inf = [tensor.detach() for tensor in self.g_inf] + self.p_inf = [tensor.detach() for tensor in self.p_inf] + # Return self after detaching everything + return self diff --git a/neuralplayground/agents/whittington_2020_extras/whittington_2020_parameters.py b/neuralplayground/agents/whittington_2020_extras/whittington_2020_parameters.py new file mode 100644 index 00000000..71da21df --- /dev/null +++ b/neuralplayground/agents/whittington_2020_extras/whittington_2020_parameters.py @@ -0,0 +1,386 @@ +import numpy as np +import torch +from scipy.special import comb + + +def parameters(): + """ + Set all parameters for the TEM model. This is a function so that it can be called from other scripts, + e.g. to load parameters from a file. + """ + params = {} + # -- World parameters + # Does this world include the standing still action? + params["has_static_action"] = True + # Number of available actions, excluding the stand still action (since standing still has an action vector + # full of zeros, it won't add to the action vector dimension) + params["n_actions"] = 4 + # Bias for explorative behaviour to pick the same action again, to encourage straight walks + params["explore_bias"] = 2 + # Rate at which environments with shiny objects occur between training environments. 
Set to 0 for no + # shiny environments at all + params["shiny_rate"] = 0 + # Discount factor in calculating Q-values to generate shiny object oriented behaviour + params["shiny_gamma"] = 0.7 + # Inverse temperature for shiny object behaviour to pick actions based on Q-values + params["shiny_beta"] = 1.5 + # Number of shiny objects in the arena + params["shiny_n"] = 2 + # Number of times to return to a shiny object after finding it + params["shiny_returns"] = 15 + # Group all shiny parameters together to pass them to the world object + params["shiny"] = { + "gamma": params["shiny_gamma"], + "beta": params["shiny_beta"], + "n": params["shiny_n"], + "returns": params["shiny_returns"], + } + + # -- Traning parameters + # Number of walks to generate + params["train_it"] = 20000 + # Number of steps to roll out before backpropagation through time + params["n_rollout"] = 20 + # Saving interval + params["save_interval"] = 1000 + # Number of environments to save + params["n_envs_save"] = 6 + # Batch size: number of walks for training simultaneously + params["batch_size"] = 16 + # Other relevant parameters + params["state_density"] = 1 + # Minimum length of a walk on one environment. Walk lengths are sampled uniformly from a window that + # shifts down until its lower limit is walk_it_min at the end of training + params["walk_it_min"] = 25 + # Maximum length of a walk on one environment. 
Walk lengths are sampled uniformly from a window that + # starts with its upper limit at walk_it_max in the beginning of training, then shifts down + params["walk_it_max"] = 300 + # Width of window from which walk lengths are sampled: at any moment, new walk lengths are sampled + # window_center +/- 0.5 * walk_it_window where window_center shifts down + params["walk_it_window"] = 0.2 * (params["walk_it_max"] - params["walk_it_min"]) + # Weights of prediction losses + params["loss_weights_x"] = 1 + # Weights of grounded location losses + params["loss_weights_p"] = 1 + # Weights of abstract location losses + params["loss_weights_g"] = 1 + # Weights of regularisation losses + params["loss_weights_reg_g"] = 0.01 + params["loss_weights_reg_p"] = 0.02 + # Weights of losses: re-balance contributions of all losses + params["loss_weights"] = torch.tensor( + [ + params["loss_weights_p"], + params["loss_weights_p"], + params["loss_weights_x"], + params["loss_weights_x"], + params["loss_weights_x"], + params["loss_weights_g"], + params["loss_weights_reg_g"], + params["loss_weights_reg_p"], + ], + dtype=torch.float, + ) + # Number of backprop iters until latent parameter losses (L_p_g, L_p_x, L_g) are all fully weighted + params["loss_weights_p_g_it"] = 2000 + # Number of backptrop iters until regularisation losses are fully weighted + params["loss_weights_reg_p_it"] = 4000 + params["loss_weights_reg_g_it"] = 40000000 + # Number of backprop iters until eta is (rate of remembering) completely 'on' + params["eta_it"] = 16000 + # Number of backprop iters until lambda (rate of forgetting) is completely 'on' + params["lambda_it"] = 200 + # Determine how much to use an offset for the standard deviation of the inferred grounded location + # to reduce its influence + params["p2g_scale_offset"] = 0 + # Additional value to offset standard deviation of inferred grounded location when inferring new + # abstract location, to reduce influence in precision weighted mean + 
params["p2g_sig_val"] = 10000 + # Set number of iterations where offset scaling should be 0.5 + params["p2g_sig_half_it"] = 400 + # Set how fast offset scaling should decrease - after p2g_sig_half_it + p2g_sig_scale_it the offset + # scaling is down to ~0.25 (1/(1+e) to be exact) + params["p2g_sig_scale_it"] = 200 + # Maximum learning rate + params["lr_max"] = 9.4e-4 + # Minimum learning rate + params["lr_min"] = 8e-5 + # Rate of learning rate decay + params["lr_decay_rate"] = 0.5 + # Steps of learning rate decay + params["lr_decay_steps"] = 4000 + # Number of rollouts in each iteration + params["n_walks"] = generate_n_walk(params) + + # -- Model parameters + # Decide whether to sample, or assume no noise and simply take mean of all distributions + params["do_sample"] = False + # Decide whether to use inferred ground location while inferring new abstract location, instead of + # only previous grounded location (James's infer_g_type) + params["use_p_inf"] = True + # Decide whether to use seperate grid modules that recieve shiny information for object vector cells. + # To disable OVC, set this False, and set n_ovc to [0 for _ in range(len(params['n_g_subsampled']))] + params["separate_ovc"] = False + # Standard deviation for initial initial g (which will then be learned) + params["g_init_std"] = 0.5 + # Standard deviation to initialise hidden to output layer of MLP for inferring new abstract location + # from memory of grounded location + params["g_mem_std"] = 0.1 + # Hidden layer size of MLP for abstract location transitions + params["d_hidden_dim"] = 20 + + # ---- Neuron and module parameters + # Neurons for subsampled entorhinal abstract location f_g(g) for each frequency module + params["n_g_subsampled"] = [10, 10, 8, 6, 6] + # Neurons for object vector cells. Neurons will get new modules if object vector cell modules are + # separated; otherwise, they are added to existing abstract location modules. 
+ # a) No additional modules, no additional object vector neurons (e.g. when not using shiny + # environments): [0 for _ in range(len(params['n_g_subsampled']))], and separate_ovc set to False + # b) No additional modules, but n additional object vector neurons in each grid module: + # [n for _ in range(len(params['n_g_subsampled']))], and separate_ovc set to False + # c) Additional separate object vector modules, with n, m neurons: [n, m], and separate_ovc set to True + params["n_ovc"] = [0 for _ in range(len(params["n_g_subsampled"]))] + # Add neurons for object vector cells. Add new modules if object vector cells get separate modules, + # or else add neurons to existing modules + params["n_g_subsampled"] = ( + params["n_g_subsampled"] + params["n_ovc"] + if params["separate_ovc"] + else [grid + ovc for grid, ovc in zip(params["n_g_subsampled"], params["n_ovc"])] + ) + # Number of hierarchical frequency modules for object vector cells + params["n_f_ovc"] = len(params["n_ovc"]) if params["separate_ovc"] else 0 + # Number of hierarchical frequency modules for grid cells + params["n_f_g"] = len(params["n_g_subsampled"]) - params["n_f_ovc"] + # Total number of modules + params["n_f"] = len(params["n_g_subsampled"]) + # Number of neurons of entorhinal abstract location g for each frequency + params["n_g"] = [3 * g for g in params["n_g_subsampled"]] + # Neurons for sensory observation x + params["n_x"] = 45 + # Neurons for compressed sensory experience x_c + params["n_x_c"] = 10 + # Neurons for temporally filtered sensory experience x for each frequency + params["n_x_f"] = [params["n_x_c"] for _ in range(params["n_f"])] + # Neurons for hippocampal grounded location p for each frequency + params["n_p"] = [g * x for g, x in zip(params["n_g_subsampled"], params["n_x_f"])] + # Initial frequencies of each module. 
For ease of interpretation (higher number = higher + # frequency) this is 1 - the frequency as James uses it + params["f_initial"] = [0.99, 0.3, 0.09, 0.03, 0.01] + # Add frequencies of object vector cell modules, if object vector cells get separate modules + params["f_initial"] = params["f_initial"] + params["f_initial"][0 : params["n_f_ovc"]] + + # ---- Memory parameters + # Use common memory for generative and inference network + params["common_memory"] = False + # Hebbian rate of forgetting + params["lambda"] = 0.9999 + # Hebbian rate of remembering + params["eta"] = 0.5 + # Hebbian retrieval decay term + params["kappa"] = 0.8 + # Number of iterations of attractor dynamics for memory retrieval + params["i_attractor"] = params["n_f_g"] + # Maximum iterations of attractor dynamics per frequency in inference model, so you can early + # stop low-frequency modules. Set to None for no early stopping + params["i_attractor_max_freq_inf"] = [params["i_attractor"] for _ in range(params["n_f"])] + # Maximum iterations of attractor dynamics per frequency in generative model, so you can early + # stop low-frequency modules. Don't early stop for object vector cell modules. + params["i_attractor_max_freq_gen"] = [params["i_attractor"] - freq_nr for freq_nr in range(params["n_f_g"])] + [ + params["i_attractor"] for _ in range(params["n_f_ovc"]) + ] + + # --- Connectivity matrices + # Set connections when forming Hebbian memory of grounded locations: from low frequency modules to high. + # High frequency modules come first (different from James!) + params["p_update_mask"] = torch.zeros((np.sum(params["n_p"]), np.sum(params["n_p"])), dtype=torch.float) + n_p = np.cumsum(np.concatenate(([0], params["n_p"]))) + # Entry M_ij (row i, col j) is the connection FROM cell i TO cell j. Memory is retrieved by + # h_t+1 = h_t * M, i.e. 
h_t+1_j = sum_i {connection from i to j * h_t_i} + for f_from in range(params["n_f"]): + for f_to in range(params["n_f"]): + # For connections that involve separate object vector modules: these are connected to all normal + # modules, but hierarchically between object vector modules + if f_from > params["n_f_g"] or f_to > params["n_f_g"]: + # If this is a connection between object vector modules: only allow for connection from + # low to high frequency + if f_from > params["n_f_g"] and f_to > params["n_f_g"]: + if params["f_initial"][f_from] <= params["f_initial"][f_to]: + params["p_update_mask"][n_p[f_from] : n_p[f_from + 1], n_p[f_to] : n_p[f_to + 1]] = 1.0 + # If this is a connection to between object vector and normal modules: allow any connections, + # in both directions + else: + params["p_update_mask"][n_p[f_from] : n_p[f_from + 1], n_p[f_to] : n_p[f_to + 1]] = 1.0 + # Else: this is a connection between abstract location frequency modules; only allow for connections + # if it goes from low to high frequency + else: + if params["f_initial"][f_from] <= params["f_initial"][f_to]: + params["p_update_mask"][n_p[f_from] : n_p[f_from + 1], n_p[f_to] : n_p[f_to + 1]] = 1.0 + # During memory retrieval, hierarchical memory retrieval of grounded location is implemented by early-stopping + # low-frequency memory updates, using a mask for updates at every retrieval iteration + params["p_retrieve_mask_inf"] = [torch.zeros(sum(params["n_p"])) for _ in range(params["i_attractor"])] + params["p_retrieve_mask_gen"] = [torch.zeros(sum(params["n_p"])) for _ in range(params["i_attractor"])] + # Build masks for each retrieval iteration + for mask, max_iters in zip( + [params["p_retrieve_mask_inf"], params["p_retrieve_mask_gen"]], + [params["i_attractor_max_freq_inf"], params["i_attractor_max_freq_gen"]], + ): + # For each frequency, we get the number of update iterations, and insert ones in the mask for those iterations + for f, max_i in enumerate(max_iters): + # Update masks up 
to maximum iteration + for i in range(max_i): + mask[i][n_p[f] : n_p[f + 1]] = 1.0 + # In path integration, abstract location frequency modules can influence the transition of other + # modules hierarchically (low to high). Set for each frequency module from which other frequencies + # input is received + params["g_connections"] = [ + [params["f_initial"][f_from] <= params["f_initial"][f_to] for f_from in range(params["n_f_g"])] + + [False for _ in range(params["n_f_ovc"])] + for f_to in range(params["n_f_g"]) + ] + # Add connections for separate object vector cell module: only between object vector cell modules - and make + # those hierarchical too + params["g_connections"] = params["g_connections"] + [ + [False for _ in range(params["n_f_g"])] + + [params["f_initial"][f_from] <= params["f_initial"][f_to] for f_from in range(params["n_f_g"], params["n_f"])] + for f_to in range(params["n_f_g"], params["n_f"]) + ] + + # ---- Static matrices + # Matrix for repeating abstract location g to do outer product with sensory information x with elementwise product. + # Also see (*) note at bottom + params["W_repeat"] = [ + torch.tensor(np.kron(np.eye(params["n_g_subsampled"][f]), np.ones((1, params["n_x_f"][f]))), dtype=torch.float) + for f in range(params["n_f"]) + ] + # Matrix for tiling sensory observation x to do outer product with abstract with elementwise product. 
+ # Also see (*) note at bottom + params["W_tile"] = [ + torch.tensor(np.kron(np.ones((1, params["n_g_subsampled"][f])), np.eye(params["n_x_f"][f])), dtype=torch.float) + for f in range(params["n_f"]) + ] + # Table for converting one-hot to two-hot compressed representation + params["two_hot_table"] = [[0] * (params["n_x_c"] - 2) + [1] * 2] + # We need a compressed code for each possible observation, but it's impossible to have more compressed codes + # than "n_x_c choose 2" + for i in range(1, min(int(comb(params["n_x_c"], 2)), params["n_x"])): + # Copy previous code + code = params["two_hot_table"][-1].copy() + # Find latest occurrence of [0 1] in that code + swap = [index for index in range(len(code) - 1, -1, -1) if code[index : index + 2] == [0, 1]][0] + # Swap those to get new code + code[swap : swap + 2] = [1, 0] + # If the first one was swapped: value after swapped pair is 1 + if swap + 2 < len(code) and code[swap + 2] == 1: + # In that case: move the second 1 all the way back - reverse everything after the swapped pair + code[swap + 2 :] = code[: swap + 1 : -1] + # And append new code to array + params["two_hot_table"].append(code) + # Convert each code to column vector pytorch tensor + params["two_hot_table"] = [torch.tensor(code) for code in params["two_hot_table"]] + # Downsampling matrix to go from grid cells to compressed grid cells for indexing memories by simply taking + # only the first n_g_subsampled grid cells + params["g_downsample"] = [ + torch.cat([torch.eye(dim_out, dtype=torch.float), torch.zeros((dim_in - dim_out, dim_out), dtype=torch.float)]) + for dim_in, dim_out in zip(params["n_g"], params["n_g_subsampled"]) + ] + return params + + +# This specifies how parameters are updated at every backpropagation iteration/gradient update +def parameter_iteration(iteration, params): + """ + Update parameters at every backpropagation iteration/gradient update. + Parameters + ---------- + iteration : int + Current iteration/gradient update. 
+ params : dict + Dictionary of parameters. + Returns + ------- + eta : float + Hebbian rate of remembering. + lamb : float + Hebbian rate of forgetting. + p2g_scale_offset : float + Scaling of variance offset for grounded location inference. + lr : float + Learning rate. + walk_length_center : float + Center of walk length window, within which the walk lenghts of new walks are uniformly sampled. + loss_weights : torch.tensor + Current loss weights. + """ + # Calculate eta (rate of remembering) and lambda (rate of forgetting) for Hebbian memory updates + eta = min((iteration + 1) / params["eta_it"], 1) * params["eta"] + lamb = min((iteration + 1) / params["lambda_it"], 1) * params["lambda"] + # Calculate current scaling of variance offset for ground location inference + p2g_scale_offset = 1 / (1 + np.exp((iteration - params["p2g_sig_half_it"]) / params["p2g_sig_scale_it"])) + # Calculate current learning rate + lr = max( + params["lr_min"] + + (params["lr_max"] - params["lr_min"]) * (params["lr_decay_rate"] ** (iteration / params["lr_decay_steps"])), + params["lr_min"], + ) + # Calculate center of walk length window, within which the walk lenghts of new walks are uniformly sampled + walk_length_center = ( + params["walk_it_max"] + - params["walk_it_window"] * 0.5 + - min((iteration + 1) / params["train_it"], 1) + * (params["walk_it_max"] - params["walk_it_min"] - params["walk_it_window"]) + ) + # Calculate current loss weights + L_p_g = min((iteration + 1) / params["loss_weights_p_g_it"], 1) * params["loss_weights_p"] + L_p_x = min((iteration + 1) / params["loss_weights_p_g_it"], 1) * params["loss_weights_p"] * (1 - p2g_scale_offset) + L_x_gen = params["loss_weights_x"] + L_x_g = params["loss_weights_x"] + L_x_p = params["loss_weights_x"] + L_g = min((iteration + 1) / params["loss_weights_p_g_it"], 1) * params["loss_weights_g"] + L_reg_g = (1 - min((iteration + 1) / params["loss_weights_reg_g_it"], 1)) * params["loss_weights_reg_g"] + L_reg_p = (1 - min((iteration 
def generate_n_walk(params):
    """
    Generate, for every training iteration, the walk length expressed in rollouts.

    The walk length in steps decreases linearly with the iteration number and reaches
    ``walk_it_min + walk_it_window / 2`` at the final iteration. This is the same expression
    that ``parameter_iteration`` uses for the centre of the walk length window. Each length is
    divided by ``n_rollout`` and rounded to the nearest integer.

    Parameters
    ----------
    params : dict
        Dictionary of parameters. Reads ``train_it``, ``walk_it_max``, ``walk_it_min``,
        ``walk_it_window`` and ``n_rollout``.
    Returns
    -------
    n_walks : list
        One integer per training iteration: walk length in steps divided by ``n_rollout``, rounded.
    """
    train_it = params["train_it"]
    n_rollout = params["n_rollout"]
    # Loop-invariant parts of the schedule: value before any decay, and total amount of decay
    start = params["walk_it_max"] - params["walk_it_window"] * 0.5
    span = params["walk_it_max"] - params["walk_it_min"] - params["walk_it_window"]
    # `iteration` (rather than `iter`) keeps the builtin iter() usable
    return [
        round((start - min((iteration + 1) / train_it, 1) * span) / n_rollout)
        for iteration in range(train_it)
    ]
+# W_repeat = np.kron(np.eye(4), np.ones((1,3))) +# W_tile = np.kron(np.ones((1,4)),np.eye(3)) +# out4 = np.matmul(g, W_repeat) * np.matmul(x,W_tile) # This is batch-proof diff --git a/neuralplayground/agents/whittington_2020_extras/whittington_2020_utils.py b/neuralplayground/agents/whittington_2020_extras/whittington_2020_utils.py new file mode 100644 index 00000000..de5b353b --- /dev/null +++ b/neuralplayground/agents/whittington_2020_extras/whittington_2020_utils.py @@ -0,0 +1,330 @@ +import copy as cp +import datetime +import logging +import os + +import numpy as np +import torch + + +def inv_var_weight(mus, sigmas): + """ + Calculates tensors of inverse-variance weighted averages and tensors of inverse-variance weighted standard deviations. + Parameters + ---------- + mus : list of torch tensors + List of tensors of means of distributions + sigmas : list of torch tensors + List of tensors of standard deviations of distributions + """ + # Stack vectors together along first dimension + mus = torch.stack(mus, dim=0) + sigmas = torch.stack(sigmas, dim=0) + # Calculate inverse variance weighted variance from sum over reciprocal of squared sigmas + inv_var_var = 1.0 / torch.sum(1.0 / (sigmas**2), dim=0) + # Calculate inverse variance weighted average + inv_var_avg = torch.sum(mus / (sigmas**2), dim=0) * inv_var_var + # Convert weigthed variance to sigma + inv_var_sigma = torch.sqrt(inv_var_var) + # And return results + return inv_var_avg, inv_var_sigma + + +def softmax(x): + """ + Calculates softmax of input tensor x. + """ + # Return torch softmax + return torch.nn.Softmax(dim=-1)(x) + + +def normalise(x): + """ + Normalises (L2) vector of input to unit norm, using torch normalise function. 
def cross_entropy(value, target):
    """
    Calculates the cross entropy between predictions value and targets target, using
    torch.nn.CrossEntropyLoss without reduction.

    Note that this is the multi-class (softmax) cross entropy, not the binary cross entropy:
    torch.nn.CrossEntropyLoss applies log-softmax to ``value`` itself, so ``value`` should hold
    unnormalised scores (logits) and ``target`` class indices or class probabilities.

    Parameters
    ----------
    value : torch tensor or list/tuple of torch tensors
        Tensor of unnormalised class scores, or a sequence of such tensors
    target : torch tensor or list/tuple of torch tensors
        Tensor of targets, or a sequence of targets matching ``value`` element by element
    Returns
    -------
    loss : torch tensor or list of torch tensors
        Unreduced cross entropy losses; a list with one tensor per input pair if
        ``value`` was a list or tuple
    """
    # One loss module is enough; reduction="none" keeps one loss per sample
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    if isinstance(value, (list, tuple)):
        return [loss_fn(val, targ) for val, targ in zip(value, target)]
    return loss_fn(value, target)
+ Parameters + ---------- + value : torch tensor + Tensor of values + target_dim : int + Target dimension of output vector + Returns + ------- + downsample : torch tensor + Tensor of values, downsampled to target_dim + """ + # Get input dimension + value_dim = value.size()[-1] + # Set places to break up input vector into chunks + edges = np.append(np.round(np.arange(0, value_dim, float(value_dim) / target_dim)), value_dim).astype(int) + # Create downsampling matrix + downsample = torch.zeros((value_dim, target_dim), dtype=torch.float) + # Fill downsampling matrix with chunks + for curr_entry in range(target_dim): + downsample[edges[curr_entry] : edges[curr_entry + 1], curr_entry] = torch.tensor( + 1.0 / (edges[curr_entry + 1] - edges[curr_entry]), dtype=torch.float + ) + # Do downsampling by matrix multiplication + return torch.matmul(value, downsample) + + +def make_directories(): + """ + Returns directories for storing data during a model training run. + """ + # Get current date for saving folder + date = datetime.datetime.today().strftime("%Y-%m-%d") + # Initialise the run and dir_check to create a new run folder within the current date + run = 0 + dir_check = True + # Initialise all pahts + train_path, model_path, save_path, script_path, run_path = None, None, None, None, None + # Find the current run: the first run that doesn't exist yet + while dir_check: + # Construct new paths + run_path = "../Summaries2/" + date + "/torch_run" + str(run) + "/" + train_path = run_path + "train" + model_path = run_path + "model" + save_path = run_path + "save" + script_path = run_path + "script" + envs_path = script_path + "/envs" + run += 1 + # And once a path doesn't exist yet: create new folders + if not os.path.exists(train_path) and not os.path.exists(model_path) and not os.path.exists(save_path): + os.makedirs(train_path) + os.makedirs(model_path) + os.makedirs(save_path) + os.makedirs(script_path) + os.makedirs(envs_path) + dir_check = False + # Return folders to new 
def set_directories(date, run):
    """
    Returns the directory paths belonging to a given date and run number.

    Only builds the path strings; nothing is created or checked on disk.

    NOTE(review): this uses "../Summaries/<date>/run<run>/" whereas make_directories creates
    "../Summaries2/<date>/torch_run<run>/" - confirm that the difference is intended.

    Parameters
    ----------
    date : str
        Date string used as folder name
    run : int
        Run number within that date
    Returns
    -------
    run_path, train_path, model_path, save_path, script_path, envs_path : str
        Run folder (with trailing slash), its train/model/save/script sub-folders, and the
        envs folder inside the script folder
    """
    run_path = f"../Summaries/{date}/run{run}/"
    train_path = run_path + "train"
    model_path = run_path + "model"
    save_path = run_path + "save"
    script_path = run_path + "script"
    envs_path = script_path + "/envs"
    return run_path, train_path, model_path, save_path, script_path, envs_path
+ Parameters + ---------- + data : list of torch tensors + List of tensors of data + prev_cell_maps : list of torch tensors + List of tensors of previous cell maps + positions : list of torch tensors + List of tensors of positions + pars : dict + Dictionary of parameters + Returns + ------- + cell_list : list of torch tensors + List of tensors of cell maps + positions : list of torch tensors + List of tensors of positions + """ + gs, ps, position = data + gs_all, ps_all = prev_cell_maps + + g1s = np.transpose(np.array(cp.deepcopy(gs)), [1, 2, 0]) + p1s = np.transpose(np.array(cp.deepcopy(ps)), [1, 2, 0]) + # pos_to = position[:][1:pars['n_rollout'] + 1] + pos_to = position + + gs_all = cell_norm_online(g1s, pos_to, gs_all, pars) + ps_all = cell_norm_online(p1s, pos_to, ps_all, pars) + + cell_list = [gs_all, ps_all] + + return cell_list, positions + + +def cell_norm_online(cells, positions, current_cell_mat, pars): + """ + Online cell normalisation. + Parameters + ---------- + cells : list of torch tensors + List of tensors of cells + positions : list of torch tensors + List of tensors of positions + current_cell_mat : list of torch tensors + List of tensors of current cell maps + pars : dict + Dictionary of parameters + Returns + ------- + new_cell_mat : list of torch tensors + List of tensors of new cell maps + """ + # for separate environments within each batch + envs = pars["diff_env_batches_envs"] + n_states = pars["n_states_world"] + n_envs_save = pars["n_envs_save"] + + num_cells = np.shape(cells)[1] + n_trials = np.shape(cells)[2] + + cell_mat = [np.zeros((n_states[envs[env]], num_cells)) for env in range(n_envs_save)] + + new_cell_mat = [None] * n_envs_save + + for env in range(n_envs_save): + for ii in range(n_trials): + position = int(positions[ii][env]["id"]) + cell_mat[env][position, :] += cells[env, :, ii] + try: + new_cell_mat[env] = cell_mat[env] + current_cell_mat[env] + except (ValueError, TypeError): + new_cell_mat[env] = cell_mat[env] + + return 
def check_wall(pre_state, new_state, wall, wall_closenes=1e-5, tolerance=1e-9):
    """
    Check whether the move from pre_state to new_state crosses a wall segment, and if it does,
    stop the move just before the wall.

    Parameters
    ----------
    pre_state : (2,) 2d-ndarray
        2d position of pre-movement
    new_state : (2,) 2d-ndarray
        2d position of post-movement
    wall : (2, 2) ndarray
        [[x1, y1], [x2, y2]] where (x1, y1) is on limit of the wall, (x2, y2) second limit of the wall
    wall_closenes : float
        how close the agent is allowed to be from the wall, as a fraction of the attempted move
    tolerance : float
        value added to the diagonal of the linear system when it is singular (e.g. movement
        parallel to the wall), so that it can still be inverted

    Returns
    -------
    new_state: (2,) 2d-ndarray
        corrected new state. If it is not crossing the wall, then the new_state stays the same, if the state cross the
        wall, new_state will be corrected to a valid place without crossing the wall
    cross_wall: bool
        True if the change in state cross a wall
    """
    # Solve wall[0] + t * (wall[1] - wall[0]) == pre_state + s * (new_state - pre_state) for (t, s)
    A = np.stack([np.diff(wall, axis=0)[0, :], -new_state + pre_state], axis=1)
    b = pre_state - wall[0, :]
    try:
        intersection = np.linalg.inv(A) @ b
    except np.linalg.LinAlgError:
        # Singular system: perturb the diagonal slightly so the inverse exists
        intersection = np.linalg.inv(A + np.identity(A.shape[0]) * tolerance) @ b

    # The segments cross only if both t and s lie within [0, 1].
    # np.all replaces np.alltrue, which no longer exists in NumPy 2.0.
    cross_wall = np.all(np.logical_and(intersection <= 1, intersection >= 0))
    if cross_wall:
        # intersection[-1] is the fraction of the move at which the wall is hit; stop just short of it
        new_state = (intersection[-1] - wall_closenes) * (new_state - pre_state) + pre_state

    return new_state, cross_wall
b/neuralplayground/arenas/batch_environment.py new file mode 100644 index 00000000..37b0ae9f --- /dev/null +++ b/neuralplayground/arenas/batch_environment.py @@ -0,0 +1,272 @@ +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np + +from neuralplayground.arenas.arena_core import Environment +from neuralplayground.arenas.simple2d import Simple2D + + +class BatchEnvironment(Environment): + def __init__(self, environment_name: str = "BatchEnv", env_class: object = Simple2D, batch_size: int = 1, **env_kwargs): + """ + Initialise a batch of environments. This is useful for training a single agent on multiple environments simultaneously. + Parameters + ---------- + environment_name: str + Name of the environment + env_class: object + Class of the environment + batch_size: int + Number of environments in the batch + **env_kwargs: dict + Keyword arguments for the environment + """ + super().__init__(environment_name, **env_kwargs) + self.batch_size = batch_size + self.batch_x_limits = env_kwargs["arena_x_limits"] + self.batch_y_limits = env_kwargs["arena_y_limits"] + self.use_behavioural_data = env_kwargs["use_behavioural_data"] + self.environments = [] + for i in range(self.batch_size): + env_kwargs["arena_x_limits"] = self.batch_x_limits[i] + env_kwargs["arena_y_limits"] = self.batch_y_limits[i] + self.environments.append(env_class(**env_kwargs)) + + self.room_widths = [np.diff(self.environments[i].arena_x_limits)[0] for i in range(self.batch_size)] + self.room_depths = [np.diff(self.environments[i].arena_y_limits)[0] for i in range(self.batch_size)] + self.state_densities = [self.environments[i].state_density for i in range(self.batch_size)] + + def reset(self, random_state: bool = True, custom_state: np.ndarray = None): + """ + Reset the environment + Parameters + ---------- + random_state: bool + If True, the agent will be placed in a random state + custom_state: np.ndarray + If not None, the agent will be placed in the state specified by 
def step(self, actions: np.ndarray, normalize_step: bool = False):
    """
    Step every environment in the batch with its own action.

    The step is only accepted if it is allowed in all environments. An environment rejects the
    step when the first entry of its state is unchanged after stepping (presumably the discrete
    state index - confirm against the environment class), unless, when not using behavioural
    data, the action itself was the null action [0, 0]. If any environment rejects the step,
    the state of every environment is restored to its old state; otherwise the transitions of
    all environments are appended to the batch history.

    Parameters
    ----------
    actions: np.ndarray
        Array of actions for each environment in the batch
    normalize_step: bool
        Passed on unchanged to the step method of every environment
    Returns
    -------
    all_observations: list of np.ndarray
        List of observations for each environment in the batch
    all_states: list of np.ndarray
        List of states for each environment in the batch
        (as returned by the environments, also when the step was rejected)
    """
    all_observations = []
    all_states = []
    all_allowed = True
    for batch, env in enumerate(self.environments):
        action = actions[batch]
        env_obs, env_state = env.step(action, normalize_step)
        if self.use_behavioural_data:
            if env.state[0] == env.old_state[0]:
                all_allowed = False
        else:
            # array_equal gives a single bool for lists and ndarrays alike; `action != [0, 0]`
            # is elementwise for ndarrays and cannot be used as a condition
            if env.state[0] == env.old_state[0] and not np.array_equal(action, [0, 0]):
                all_allowed = False
        all_observations.append(env_obs)
        all_states.append(env_state)

    if not all_allowed:
        # Undo the step everywhere so that all environments stay in sync
        for env in self.environments:
            env.state = env.old_state
    else:
        self.history.append([env.transition for env in self.environments])

    return all_observations, all_states
history data saved as attribute of the arena, use custom otherwise + ax: mpl.axes._subplots.AxesSubplot (matplotlib axis from subplots) + axis from subplot from matplotlib where the trajectory will be plotted. + return_figure: bool + If true, it will return the figure variable generated to make the plot + save_path: str, list of str, tuple of str + saving path of the generated figure, if None, no figure is saved + Returns + ------- + ax: mpl.axes._subplots.AxesSubplot (matplotlib axis from subplots) + Modified axis where the trajectory is plotted + f: matplotlib.figure + if return_figure parameters is True + """ + env = self.environments[0] + # Use or not saved history + if history_data is None: + history_data = [his[0] for his in self.history] + + # Generate Figure + if ax is None: + f, ax = plt.subplots(1, 1, figsize=(8, 6)) + + # Draw walls + for wall in env.default_walls: + ax.plot(wall[:, 0], wall[:, 1], "C3", lw=3) + + # Draw custom walls + for wall in env.custom_walls: + ax.plot(wall[:, 0], wall[:, 1], "C0", lw=3) + + # Making the trajectory plot roughly square to show structure of the arena better + lower_lim, upper_lim = np.amin(env.arena_limits), np.amax(env.arena_limits) + ax.set_xlim([lower_lim, upper_lim]) + ax.set_ylim([lower_lim, upper_lim]) + + # Make plot of positions + if len(history_data) != 0: + state_history = [s["state"][-1] for s in history_data] + next_state_history = [s["next_state"][-1] for s in history_data] + state_history[0] + next_state_history[-1] + + cmap = mpl.cm.get_cmap("plasma") + norm = plt.Normalize(0, len(state_history)) + + aux_x = [] + aux_y = [] + for i, s in enumerate(state_history): + if i % plot_every == 0: + if i + plot_every >= len(state_history): + break + x_ = [s[0], state_history[i + plot_every][0]] + y_ = [s[1], state_history[i + plot_every][1]] + aux_x.append(s[0]) + aux_y.append(s[1]) + sc = ax.plot(x_, y_, "-", color=cmap(norm(i)), alpha=0.6) + + sc = ax.scatter(aux_x, aux_y, c=np.arange(len(aux_x)), vmin=0, 
def round_to_nearest_state_center(self, x, y):
    """
    Round the (x, y) coordinates to the center of the nearest state of the first environment
    in the batch.

    Parameters
    ----------
    x: float
        x coordinate
    y: float
        y coordinate
    Returns
    -------
    rounded_x: float
        x coordinate rounded to the center of the nearest state
    rounded_y: float
        y coordinate rounded to the center of the nearest state
    """
    # state_densities holds one density per environment in the batch, so the density of the
    # first environment gives the state size in both directions. Indexing [1] would read the
    # second environment and fail for a batch of one.
    state_width = 1 / self.state_densities[0]
    state_depth = state_width

    half_state_width = state_width / 2
    half_state_depth = state_depth / 2

    # Shift by half a state so that state centers fall on multiples of the state size, round, shift back
    rounded_x = round((x + half_state_width) / state_width) * state_width - half_state_width
    rounded_y = round((y + half_state_depth) / state_depth) * state_depth - half_state_depth

    return rounded_x, rounded_y
env_kwargs arguments are specific for each of the child environments and + described in their respective class. + reset(self): + Re-initialize state and global counters. Resets discrete objects at each state. + generate_objects(self): + Randomly distribute objects (one-hot encoded vectors) at each discrete state within the environment + make_observation(self, step): + Convert an (x,y) position into an observation of an object + pos_to_state(self, step): + Convert an (x,y) position to a discretised state index + plot_objects(self, history_data=None, ax=None, return_figure=False): + + Attributes + ---------- + state: array + Empty array for this abstract class + history: list + Contains transition history + env_kwags: dict + Arguments given to the init method + global_steps: int + Number of calls to step method, set to 0 when calling reset + global_time: float + Time simulating environment through step method, then global_time = time_step_size * global_steps + number_object: int + The number of possible objects present at any state + room_width: int + Size of the environment in the x direction + room_depth: int + Size of the environment in the y direction + state_density: int + The density of discrete states in the environment + """ + + def __init__( + self, + recording_index: int = None, + environment_name: str = "DiscreteObject", + verbose: bool = False, + experiment_class: str = None, + **env_kwargs, + ): + """ + Initialize the class. env_kwargs arguments are specific for each of the child environments and + described in their respective class. 
+ Parameters + ---------- + environment_name: str + Name of the environment + verbose: bool + If True, print information about the environment + experiment_class: str + Name of the class of the experiment to use + **env_kwargs: + Arguments specific to each environment + """ + super().__init__(environment_name, **env_kwargs) + self.environment_name = environment_name + self.use_behavioral_data = env_kwargs["use_behavioural_data"] + self.experiment = experiment_class( + experiment_name=self.environment_name, + data_path=env_kwargs["data_path"], + recording_index=recording_index, + verbose=verbose, + ) + if self.use_behavioral_data: + self.state_dims_labels = ["x_pos", "y_pos", "head_direction_x", "head_direction_y"] + self.arena_limits = self.experiment.arena_limits + self.arena_x_limits = self.arena_limits[0].astype(int) + self.arena_y_limits = self.arena_limits[1].astype(int) + else: + self.state_dims_labels = ["x_pos", "y_pos"] + self.arena_x_limits = env_kwargs["arena_x_limits"] + self.arena_y_limits = env_kwargs["arena_y_limits"] + + self.n_objects = env_kwargs["n_objects"] + self.state_density = env_kwargs["state_density"] + self.arena_limits = np.array( + [[self.arena_x_limits[0], self.arena_x_limits[1]], [self.arena_y_limits[0], self.arena_y_limits[1]]] + ) + self.room_width = np.diff(self.arena_x_limits)[0] + self.room_depth = np.diff(self.arena_y_limits)[0] + self.agent_step_size = env_kwargs["agent_step_size"] + self._create_default_walls() + self._create_custom_walls() + self.wall_list = self.default_walls + self.custom_walls + + # Variables for discretised state space + self.resolution_w = int(self.state_density * self.room_width) + self.resolution_d = int(self.state_density * self.room_depth) + self.x_array = np.linspace( + -self.room_width / 2 + (1 / 2 * self.state_density), + self.room_width / 2 - (1 / 2 * self.state_density), + num=self.resolution_w, + ) + self.y_array = np.linspace( + -self.room_depth / 2 + (1 / 2 * self.state_density), + 
self.room_depth / 2 - (1 / 2 * self.state_density), + num=self.resolution_d, + ) + self.mesh = np.array(np.meshgrid(self.x_array, self.y_array)) + self.xy_combination = np.stack(np.array(np.meshgrid(self.x_array, self.y_array)), axis=-1) + self.ws = int(self.room_width * self.state_density) + self.hs = int(self.room_depth * self.state_density) + self.n_states = self.resolution_w * self.resolution_d + self.objects = np.empty(shape=(self.n_states, self.n_objects)) + + def reset(self, random_state=True, custom_state=None): + """ + Reset the environment variables and distribution of sensory objects. + Parameters + ---------- + random_state: bool + If True, sample a new position uniformly within the arena, use default otherwise + custom_state: np.ndarray + If given, use this array to set the initial state + + Returns + ---------- + observation: ndarray + Because this is a fully observable environment, make_observation returns the state of the environment + Array of the observation of the agent in the environment (Could be modified as the environments are evolves) + + self.state: ndarray (2,) + Vector of the x and y coordinate of the position of the animal in the environment + """ + + self.global_steps = 0 + self.global_time = 0 + self.history = [] + self.state = [-1, -1, [self.room_width + 1, self.room_depth + 1]] + if random_state: + pos = [ + np.random.uniform(low=self.arena_limits[0, 0], high=self.arena_limits[0, 1]), + np.random.uniform(low=self.arena_limits[1, 0], high=self.arena_limits[1, 1]), + ] + else: + pos = np.array([0, 0]) + + if custom_state is not None: + pos = np.array(custom_state) + + # Reset to first position recorded in this session + if self.use_behavioral_data: + pos, head_dir = self.experiment.position[0, :], self.experiment.head_direction[0, :] + custom_state = np.concatenate([pos, head_dir]) + + self.objects = self.generate_objects() + + # Fully observable environment, make_observation returns the state + observation = 
self.make_object_observation(pos) + self.state = observation + return observation, self.state + + def step(self, action: np.ndarray, normalize_step: bool = False, skip_every: int = 10): + """ + Runs the environment dynamics. Increasing global counters. Given some action, return observation, + new state and reward. + + Parameters + ---------- + action: ndarray (2,) + Array containing the action of the agent, in this case the delta_x and detla_y increment to position + normalize_step: bool + If true, the action is normalized to have unit size, then scaled by the agent step size + skip_every: int + When using behavioral data, the next state will be the position and head direction + "skip_every" recorded steps after the current one, essentially reducing the sampling rate + + Returns + ------- + reward: float + The reward that the animal receives in this state + new_state: ndarray + Update the state with the updated vector of coordinate x and y of position and head directions respectively + observation: ndarray + Array of the observation of the agent in the environment, in this case the sensory object. 
+ """ + self.old_state = self.state.copy() + if self.use_behavioral_data: + # In this case, the action is ignored and computed from the step in behavioral data recorded from the experiment + if self.global_steps * skip_every >= self.experiment.position.shape[0] - 1: + self.global_steps = np.random.choice(np.arange(skip_every)) + # Time elapsed since last reset + self.global_time = self.global_steps * self.time_step_size + + # New state as "skip every" steps after the current one in the recording + new_pos_state = ( + self.experiment.position[self.global_steps * skip_every, :], + self.experiment.head_direction[self.global_steps * skip_every, :], + ) + new_pos_state = np.concatenate(new_pos_state) + else: + if action[0] == 0: + action_rev = np.array([0.0, -action[1]]) + else: + action_rev = action + if normalize_step and np.linalg.norm(action) > 0: + action_rev = action_rev / np.linalg.norm(action_rev) + new_pos_state = self.state[-1] + self.agent_step_size * action_rev + else: + new_pos_state = self.state[-1] + action_rev + new_pos_state, valid_action = self.validate_action(self.state[-1], action_rev, new_pos_state) + reward = self.reward_function(action, self.state[-1]) # If you get reward, it should be coded here + observation = self.make_object_observation(new_pos_state) + self.state = observation + self.transition = { + "action": action, + "state": self.old_state, + "next_state": self.state, + "reward": reward, + "step": self.global_steps, + } + # self.history.append(transition) + self._increase_global_step() + return observation, self.state + + def generate_objects(self): + """ + Generate objects in the environment. In this case, the objects are one-hot encoded vectors. 
+ Returns + ------- + objects: ndarray (n_states, n_objects) + Array of the objects in the environment, one-hot encoded + """ + poss_objects = np.zeros(shape=(self.n_objects, self.n_objects)) + for i in range(self.n_objects): + for j in range(self.n_objects): + if j == i: + poss_objects[i][j] = 1 + # Generate landscape of objects in each environment + objects = np.zeros(shape=(self.n_states, self.n_objects)) + for i in range(self.n_states): + rand = random.randint(0, self.n_objects - 1) + objects[i, :] = poss_objects[rand] + return objects + + def make_object_observation(self, pos): + """ + Make an observation of the object in the environment at the current position. + Parameters + ---------- + pos: ndarray (2,) + Vector of the x and y coordinate of the position of the animal in the environment + Returns + ------- + observation: ndarray (n_objects,) + Array of the observation of the agent in the environment, in this case the sensory object. + """ + index = self.pos_to_state(np.array(pos)) + object = self.objects[index] + + return [index, object, pos] + + def pos_to_state(self, pos): + """ + Convert an (x,y) position to a discretised state index + Parameters + ---------- + pos: ndarray (2,) + Vector of the x and y coordinate of the position of the animal in the environment + Returns + ------- + index: int + Index of the state in the discretised state space + """ + if len(pos) > 2: + pos = pos[:2] + diff = (self.xy_combination - pos) ** 2 + dist = np.sum(diff**2, axis=-1) + index = np.argmin(dist) + return index + + def _create_default_walls(self): + """Generate walls to limit the arena based on the limits given in kwargs when initializing the object. + Each wall is presented by a matrix + [[xi, yi], + [xf, yf]] + where xi and yi are x y coordinates of one limit of the wall, and xf and yf are coordinates of the other limit. + Walls are added to default_walls list, to then merge it with custom ones. + See notebook with custom arena examples. 
+ """ + self.default_walls = [] + self.default_walls.append( + np.array([[self.arena_limits[0, 0], self.arena_limits[1, 0]], [self.arena_limits[0, 0], self.arena_limits[1, 1]]]) + ) + self.default_walls.append( + np.array([[self.arena_limits[0, 1], self.arena_limits[1, 0]], [self.arena_limits[0, 1], self.arena_limits[1, 1]]]) + ) + self.default_walls.append( + np.array([[self.arena_limits[0, 0], self.arena_limits[1, 0]], [self.arena_limits[0, 1], self.arena_limits[1, 0]]]) + ) + self.default_walls.append( + np.array([[self.arena_limits[0, 0], self.arena_limits[1, 1]], [self.arena_limits[0, 1], self.arena_limits[1, 1]]]) + ) + + def _create_custom_walls(self): + """Custom walls method. In this case is empty since the environment is a simple square room + Override this method to generate more walls, see jupyter notebook with examples""" + self.custom_walls = [] + + def validate_action(self, pre_state, action, new_state): + """Check if the new state is crossing any walls in the arena. + + Parameters + ---------- + pre_state : (2,) 2d-ndarray + 2d position of pre-movement + new_state : (2,) 2d-ndarray + 2d position of post-movement + + Returns + ------- + new_state: (2,) 2d-ndarray + corrected new state. 
If it is not crossing the wall, then the new_state stays the same, if the state cross the + wall, new_state will be corrected to a valid place without crossing the wall + crossed_wall: bool + True if the change in state crossed a wall and was corrected + """ + crossed_wall = False + for wall in self.wall_list: + new_state, crossed = check_crossing_wall(pre_state=pre_state, new_state=new_state, wall=wall) + crossed_wall = crossed or crossed_wall + return new_state, crossed_wall + + # to be written again here + def plot_objects(self, history_data=None, ax=None, return_figure=False): + """Plot the Trajectory of the agent in the environment + + Parameters + ---------- + history_data: None + default to access to the saved history of positions in the environment + ax: None + default to create ax + Returns + ------- + Returns a plot of the trajectory of the animal in the environment + """ + if history_data is None: + history_data = self.history + if ax is None: + f, ax = plt.subplots(1, 1, figsize=(8, 6)) + + for wall in self.default_walls: + ax.plot(wall[:, 0], wall[:, 1], "C3", lw=3) + + for wall in self.custom_walls: + ax.plot(wall[:, 0], wall[:, 1], "C0", lw=3) + + if return_figure: + return f, ax + else: + return ax diff --git a/neuralplayground/experiments/sargolini_2006_data.py b/neuralplayground/experiments/sargolini_2006_data.py index ce5eb941..980b2e18 100644 --- a/neuralplayground/experiments/sargolini_2006_data.py +++ b/neuralplayground/experiments/sargolini_2006_data.py @@ -138,8 +138,8 @@ def _load_data(self): self.best_recording_index = 0 # Nice session recording as default # Arena limits from the experimental setting, first row x limits, second row y limits, in cm self.arena_limits = np.array([[-50.0, 50.0], [-50.0, 50.0]]) - data_path_list = glob.glob(self.data_path + "*.mat") - mice_ids = np.unique([dp.split("/")[-1][:5] for dp in data_path_list]) + data_path_list = glob.glob(os.path.join(self.data_path, "*.mat")) + mice_ids = 
np.unique([os.path.basename(dp)[:5] for dp in data_path_list]) # Initialize data dictionary, later handled by this object itself (so don't worry about this) self.data_per_animal = {} for m_id in mice_ids: