
Commit

comparison sargo
ClementineDomine committed Jul 15, 2023
1 parent e4eaaf9 commit b6c6c52
Showing 3 changed files with 164 additions and 10,122 deletions.
10,249 changes: 162 additions & 10,087 deletions examples/comparisons_examples/Comparision_from_manadger.ipynb

Large diffs are not rendered by default.

35 changes: 0 additions & 35 deletions neuralplayground/agents/stachenfeld_2018.py
@@ -396,41 +396,6 @@ def plot_transition(self, save_path: str = None, ax: mpl.axes.Axes = None):
plt.close("all")
return ax

    def plot_eigen(
        self,
        matrix: np.ndarray,
        save_path: str,
        eigen_vectors=(0, 1),
        ax: mpl.axes.Axes = None,
    ):
        """
        Plot the selected eigenvectors of the eigendecomposition of a matrix

        Parameters
        ----------
        matrix: np.ndarray
            The matrix that will be decomposed and plotted
        save_path: str
            Path to save the plot; if None, the figure is not saved
        eigen_vectors: tuple
            Indices of the eigenvectors to plot
        ax: mpl.axes.Axes
            Axis (or array of axes) to plot on; created if None
        """
        evals, evecs = np.linalg.eig(matrix)
        if ax is None:
            f, ax = plt.subplots(1, len(eigen_vectors), figsize=(4 * len(eigen_vectors), 5))
        if len(eigen_vectors) == 1:
            evecs_0 = evecs[:, eigen_vectors[0]].reshape(self.depth, self.width).real
            make_plot_rate_map(evecs_0, ax, "Eig_0", "width", "depth", "Firing rate")
        else:
            for i, eig in enumerate(eigen_vectors):
                evecs_0 = evecs[:, eig].reshape(self.depth, self.width).real
                make_plot_rate_map(evecs_0, ax[i], "Eig" + str(eig), "width", "depth", "Firing rate")
        if save_path is not None:
            plt.savefig(save_path, bbox_inches="tight")
        return ax

    def get_rate_map_matrix(
        self,
        sr_matrix,
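
For context on the deletion above: plot_eigen eigendecomposed a matrix (e.g. a successor-representation matrix), reshaped the selected eigenvectors to the environment grid, and rendered them as rate maps. A minimal standalone sketch of the same idea, assuming a depth x width grid and using a plain imshow in place of the repository's make_plot_rate_map helper:

import matplotlib.pyplot as plt
import numpy as np

depth, width = 10, 10  # stand-ins for self.depth and self.width
n_states = depth * width
matrix = np.random.rand(n_states, n_states)  # e.g. an SR or transition matrix
evals, evecs = np.linalg.eig(matrix)

eigen_vectors = (0, 1)  # which eigenvectors to plot, as in the removed default
f, ax = plt.subplots(1, len(eigen_vectors), figsize=(4 * len(eigen_vectors), 5))
for i, eig in enumerate(eigen_vectors):
    rate_map = evecs[:, eig].reshape(depth, width).real  # real part, as above
    ax[i].imshow(rate_map)  # make_plot_rate_map wraps a plot of this kind
    ax[i].set_title("Eig" + str(eig))
plt.savefig("eigen_modes.png", bbox_inches="tight")
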
2 changes: 2 additions & 0 deletions neuralplayground/backend/training_loops.py
@@ -26,13 +26,15 @@ def default_training_loop(agent: AgentCore, env: Environment, n_steps: int):

    obs, state = env.reset()
    training_hist = []
    obs = obs[:2]
    for j in range(round(n_steps)):
        # Observe to choose an action
        action = agent.act(obs)
        # Run environment for given action
        obs, state, reward = env.step(action)
        update_output = agent.update()
        training_hist.append(update_output)
        obs = obs[:2]
    process_training_hist(training_hist)
    return agent, env, dict

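
The two added obs = obs[:2] lines truncate the observation to its first two components, once after env.reset() and once after each env.step(), so agent.act only ever sees a 2-vector (presumably an (x, y) position, with any extra fields dropped). A minimal sketch of the loop pattern after this change; the stub agent and environment below are hypothetical stand-ins, not the NeuralPlayground classes:

import numpy as np

class StubAgent:
    def act(self, obs):
        return np.random.uniform(-1, 1, size=2)  # random 2-D action

    def update(self):
        return {"loss": 0.0}  # stand-in for a real update record

class StubEnv:
    def reset(self):
        # observation carries (x, y) plus an extra field the agent must not see
        return np.array([0.0, 0.0, 1.0]), None

    def step(self, action):
        return np.array([action[0], action[1], 1.0]), None, 0.0

agent, env = StubAgent(), StubEnv()
obs, state = env.reset()
obs = obs[:2]  # keep only the first two components, as in the diff
for j in range(10):
    action = agent.act(obs)
    obs, state, reward = env.step(action)
    agent.update()
    obs = obs[:2]  # truncate again after every step
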
