-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TEM branch merger #120
Open
LukeHollingsworth
wants to merge
67
commits into
main
Choose a base branch
from
whittington_2020
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
TEM branch merger #120
Changes from 63 commits
Commits
Show all changes
67 commits
Select commit
Hold shift + click to select a range
004bff2
Add options for Windows and Unix OS in README
LukeHollingsworth 86d403b
Merge branch 'main' of https://github.com/SainsburyWellcomeCentre/Neu…
LukeHollingsworth f84c0b3
Merge branch 'main' of https://github.com/SainsburyWellcomeCentre/Neu…
LukeHollingsworth 29b678c
adding experimental runs to TEM
LukeHollingsworth e36c848
batch environment example working with Simple2D
29f1dfb
default argument of BatchEnvironment() set to DiscreteObjectEnvironme…
f634212
default argument of BatchEnvironment() set to DiscreteObjectEnvironme…
ad5cab6
merge main
ClementineDomine 2b0f9d8
Update README.md - Centered logo
JarvisDevon 034e685
debugging state density plot
60abefa
Merge branch 'whittington_2020' of https://github.com/SainsburyWellco…
68422a1
pre-commit changes
55d740e
change TEM imports to not require torch install
LukeHollingsworth 2f78dee
note on installing dependencies on zsh shell
LukeHollingsworth 3fae384
merged main into whittington_2020
LukeHollingsworth a60d64b
introduce logging of training accuracies
LukeHollingsworth ff12f40
pre-commit changes
LukeHollingsworth 5c4fd53
added comments to TEM run file
LukeHollingsworth 82a34d7
merge from main
48058c3
batch trajectories and grids plotted
4cd1f7a
Simple2D & DiscreteObject examples added for BatchEnvironment
a8b07cf
attempting to fix large file problem
ccc584a
running TEM tests
f67b4c2
slurm updated
978f001
slurm updated
348a161
slurm change
aebb6fb
huge 50K run added
c5762df
huge 50K run added
723c16d
state density and history bugs sorted
0aab239
TEM state density bugs fixed
1672921
big high density run added
7128313
small TEM run
eab0cdf
state density mismatch fixed
36f5da1
small training run (without width 2) added
18c4abb
medium size run added
5bc718b
problem with state assignment fixed
ccb394e
reduced slurm memory pool
94b8ac8
reduced slurm memory pool
7de0832
updated test
5d33231
pre-commit run on all files
0d277b1
is the cluster broken or is it just me?
74990a6
trying cpu slurm
ca1c310
trying cpu slurm
a03ada0
trying cpu slurm
92de616
looped walks added
0e93183
looping walk
a681c6b
cpu slurm added
5b22e32
cpu slurm added
78cb5bb
trying to fix slurm bug
2c01eac
big memory run with longer walks
078d41f
new training config
ab297dc
formatted
cc0ad77
full var walks added
709fc82
trailing whitespace
0ead088
full length training
1fe31ab
recent TEM updates
2a95433
minor update
d2bdd14
test push
853d185
black precommit changes
4afe24c
precommit black
d423f89
pre-merge
3351dd8
premerge to main
d2cb6c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 3921cf6
Merge remote-tracking branch 'origin/main' into whittington_2020
a6911d9
Merge branch 'whittington_2020' of https://github.com/SainsburyWellco…
2993e06
starting the cleaning process
62df4d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,247 @@ | ||
import re | ||
from collections import defaultdict | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
|
||
def parse_and_plot_run_log(file_path): | ||
iterations = [] | ||
losses = [] | ||
accuracies_p = [] | ||
accuracies_g = [] | ||
accuracies_gt = [] | ||
new_walks = [] | ||
|
||
iter_pattern = r"Finished backprop iter (\d+)" | ||
loss_pattern = r"Loss: ([\d.]+)\." # Note the added \. to catch the trailing period | ||
accuracy_pattern = r"Accuracy: <p> ([\d.]+)% <g> ([\d.]+)% <gt> ([\d.]+)%" | ||
new_walk_pattern = r"Iteration (\d+): new walk" | ||
|
||
with open(file_path, "r") as file: | ||
for line in file: | ||
iter_match = re.search(iter_pattern, line) | ||
if iter_match: | ||
iterations.append(int(iter_match.group(1))) | ||
|
||
loss_match = re.search(loss_pattern, line) | ||
if loss_match: | ||
losses.append(float(loss_match.group(1))) | ||
|
||
accuracy_match = re.search(accuracy_pattern, line) | ||
if accuracy_match: | ||
accuracies_p.append(float(accuracy_match.group(1))) | ||
accuracies_g.append(float(accuracy_match.group(2))) | ||
accuracies_gt.append(float(accuracy_match.group(3))) | ||
|
||
new_walk_match = re.search(new_walk_pattern, line) | ||
if new_walk_match: | ||
new_walks.append(int(new_walk_match.group(1))) | ||
|
||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12)) | ||
|
||
ax1.plot(iterations, losses) | ||
ax1.set_xlabel("Iteration") | ||
ax1.set_ylabel("Loss") | ||
ax1.set_title("Loss over Iterations") | ||
|
||
ax2.plot(iterations, accuracies_p, label="p accuracy") | ||
ax2.plot(iterations, accuracies_g, label="g accuracy") | ||
ax2.plot(iterations, accuracies_gt, label="gt accuracy") | ||
ax2.set_xlabel("Iteration") | ||
ax2.set_ylabel("Accuracy (%)") | ||
ax2.set_title("Accuracies over Iterations") | ||
ax2.legend() | ||
|
||
# Add vertical lines for new walks | ||
# for walk in new_walks: | ||
# ax1.axvline(x=walk, color='r', linestyle='--', alpha=0.5) | ||
# ax2.axvline(x=walk, color='r', linestyle='--', alpha=0.5) | ||
|
||
plt.tight_layout() | ||
plt.show() | ||
|
||
|
||
def analyse_log_file(file): | ||
# Regular expressions to match IDs and Objs lines | ||
id_pattern = re.compile(r"IDs: \[([^\]]+)\]") | ||
obj_pattern = re.compile(r"Objs: \[([^\]]+)\]") | ||
iter_pattern = re.compile(r"Finished backprop iter (\d+)") | ||
step_pattern = re.compile(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}:") | ||
|
||
# Initialize data structures | ||
iteration_data = {} | ||
current_iteration = None | ||
id_to_obj_previous = {} | ||
current_step = 0 | ||
|
||
with open(file, "r") as file: | ||
for line in file: | ||
# Check for iteration number | ||
iter_match = iter_pattern.search(line) | ||
if iter_match: | ||
current_iteration = int(iter_match.group(1)) | ||
current_step = 0 # Reset step counter for new iteration | ||
continue # Proceed to next line | ||
|
||
# Check for step line (assumed to start with timestamp) | ||
if step_pattern.match(line): | ||
current_step += 1 | ||
|
||
# Extract IDs | ||
id_match = id_pattern.search(line) | ||
if id_match: | ||
ids = list(map(int, id_match.group(1).split(","))) | ||
continue # IDs are followed by Objs, proceed to next line | ||
|
||
# Extract Objs | ||
obj_match = obj_pattern.search(line) | ||
if obj_match: | ||
objs = list(map(int, obj_match.group(1).split(","))) | ||
|
||
# Ensure current_iteration is set | ||
if current_iteration is None: | ||
continue # Skip if iteration is not identified yet | ||
|
||
# Store IDs and Objs for this iteration and step | ||
if current_iteration not in iteration_data: | ||
iteration_data[current_iteration] = [] | ||
iteration_data[current_iteration].append((current_step, ids, objs)) | ||
|
||
# Now, process the data to find shifts with detailed information | ||
shifts = defaultdict(list) # Key: iteration, Value: list of shift details | ||
id_to_obj_current = {} | ||
|
||
sorted_iterations = sorted(iteration_data.keys()) | ||
|
||
for idx, iteration in enumerate(sorted_iterations): | ||
steps = iteration_data[iteration] | ||
# For each step in the iteration | ||
for step in steps: | ||
step_num, ids, objs = step | ||
# For each ID in the batch | ||
for batch_idx, (id_, obj) in enumerate(zip(ids, objs)): | ||
key = (batch_idx, id_) # Identify by batch index and ID | ||
if key in id_to_obj_previous: | ||
prev_info = id_to_obj_previous[key] | ||
prev_obj = prev_info["obj"] | ||
if obj != prev_obj: | ||
# Environment has changed for this batch member | ||
shifts[iteration].append( | ||
{ | ||
"batch_idx": batch_idx, | ||
"id": id_, | ||
"prev_obj": prev_obj, | ||
"new_obj": obj, | ||
"prev_iteration": prev_info["iteration"], | ||
"prev_step": prev_info["step"], | ||
"current_iteration": iteration, | ||
"current_step": step_num, | ||
} | ||
) | ||
# Update current mapping | ||
id_to_obj_current[key] = {"obj": obj, "iteration": iteration, "step": step_num} | ||
# After processing all steps in the iteration, update previous mapping | ||
id_to_obj_previous = id_to_obj_current.copy() | ||
id_to_obj_current.clear() | ||
|
||
# Output the iterations where shifts occurred with detailed information | ||
print("Environment shifts detected with detailed information:") | ||
with open("shifts_output.txt", "w") as output_file: | ||
for iteration in sorted(shifts.keys()): | ||
shift_list = shifts[iteration] | ||
if shift_list: | ||
output_file.write(f"\nIteration {iteration}: number of shifts = {len(shift_list)}\n") | ||
for shift in shift_list: | ||
output_file.write( | ||
f" Batch index {shift['batch_idx']}, ID {shift['id']} changed from " | ||
f"object {shift['prev_obj']} (Iteration {shift['prev_iteration']}, Step {shift['prev_step']}) " | ||
f"to object {shift['new_obj']} (Iteration {shift['current_iteration']},\ | ||
Step {shift['current_step']})\n" | ||
) | ||
|
||
|
||
def plot_loss_with_switches(log_file_path, output_file_path, large_switch_threshold): | ||
# Initialize lists to store data | ||
iterations = [] | ||
losses = [] | ||
large_switch_iterations = [] | ||
switch_counts = {} | ||
|
||
# Regular expressions to match lines in the log | ||
loss_pattern = re.compile(r"Loss: ([\d\.]+)") | ||
iteration_pattern = re.compile(r"Finished backprop iter (\d+)") | ||
# For the output file with switches | ||
switch_iteration_pattern = re.compile(r"Iteration (\d+): number of shifts = (\d+)") | ||
|
||
# Parse the training log file | ||
with open(log_file_path, "r") as log_file: | ||
current_iteration = None | ||
for line in log_file: | ||
# Check for iteration number | ||
iteration_match = iteration_pattern.search(line) | ||
if iteration_match: | ||
current_iteration = int(iteration_match.group(1)) | ||
iterations.append(current_iteration) | ||
continue # Move to the next line | ||
|
||
# Check for loss value | ||
loss_match = loss_pattern.search(line) | ||
if loss_match and current_iteration is not None: | ||
loss = float(loss_match.group(1)[:-1]) | ||
losses.append(loss) | ||
continue # Move to the next line | ||
|
||
# Parse the output file to get switch information | ||
with open(output_file_path, "r") as output_file: | ||
for line in output_file: | ||
# Check for switch iteration | ||
switch_iter_match = switch_iteration_pattern.match(line) | ||
if switch_iter_match: | ||
iteration = int(switch_iter_match.group(1)) | ||
num_shifts = int(switch_iter_match.group(2)) | ||
# Record iterations with shifts exceeding the threshold | ||
if num_shifts >= large_switch_threshold: | ||
large_switch_iterations.append(iteration) | ||
switch_counts[iteration] = num_shifts | ||
|
||
# Ensure the lists are aligned | ||
iterations = iterations[: len(losses)] | ||
|
||
# Plotting the loss over iterations | ||
plt.figure(figsize=(12, 6)) | ||
plt.plot(iterations, losses, label="Training Loss", color="blue") | ||
|
||
# Add markers for iterations with large switches | ||
for switch_iter in large_switch_iterations: | ||
if switch_iter in iterations: | ||
idx = iterations.index(switch_iter) | ||
plt.axvline(x=switch_iter, color="red", linestyle="--", alpha=0.5) | ||
# Optionally, add a text annotation for the number of shifts | ||
plt.text( | ||
switch_iter, | ||
losses[idx], | ||
f"{switch_counts[switch_iter]} shifts", | ||
rotation=90, | ||
va="bottom", | ||
ha="center", | ||
color="red", | ||
fontsize=8, | ||
) | ||
|
||
plt.title("Training Loss over Iterations with Large Batch Index Switches") | ||
plt.xlabel("Iteration") | ||
plt.ylabel("Loss") | ||
plt.legend() | ||
plt.grid(True) | ||
plt.show() | ||
|
||
|
||
parse_and_plot_run_log( | ||
"/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/begging_full/run.log" | ||
) | ||
# analyse_log_file('/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples | ||
# /test/run.log') | ||
# plot_loss_with_switches('/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples | ||
# /agent_examples/test/run.log', | ||
# '/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/ | ||
# test/output.txt', 50) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import importlib | ||
LukeHollingsworth marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import os | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from neuralplayground.plotting import PlotSim | ||
|
||
# simulation_id = "examples/agent_examples/TEM_test_with_break" | ||
simulation_id = "TEM_test_witch_break" | ||
save_path = simulation_id + "/" | ||
plotting_loop_params = {"n_walk": 200} | ||
|
||
training_dict = pd.read_pickle(os.path.join(os.getcwd(), save_path, "params.dict")) | ||
model_weights = pd.read_pickle(os.path.join(save_path, "agent")) | ||
model_spec = importlib.util.spec_from_file_location("model", save_path + "whittington_2020_model.py") | ||
model = importlib.util.module_from_spec(model_spec) | ||
model_spec.loader.exec_module(model) | ||
params = pd.read_pickle(os.path.join(save_path, "agent_hyper")) | ||
tem = model.Model(params) | ||
tem.load_state_dict(model_weights) | ||
|
||
sim = PlotSim( | ||
simulation_id=simulation_id, | ||
agent_class=training_dict["agent_class"], | ||
agent_params=training_dict["agent_params"], | ||
env_class=training_dict["env_class"], | ||
env_params=training_dict["env_params"], | ||
plotting_loop_params=plotting_loop_params, | ||
) | ||
|
||
trained_agent, trained_env = sim.plot_sim(save_path, random_state=False, custom_state=[0.0, 0.0]) | ||
# trained_env.plot_trajectories(); | ||
|
||
max_steps_per_env = np.random.randint(4000, 5000, size=params["batch_size"]) | ||
current_steps = np.zeros(params["batch_size"], dtype=int) | ||
|
||
obs, state = trained_env.reset(random_state=False, custom_state=[0.0, 0.0]) | ||
for i in range(200): | ||
while trained_agent.n_walk < params["n_rollout"]: | ||
actions = trained_agent.batch_act(obs) | ||
obs, state, reward = trained_env.step(actions, normalize_step=True) | ||
trained_agent.update() | ||
|
||
current_steps += params["n_rollout"] | ||
finished_walks = current_steps >= max_steps_per_env | ||
if any(finished_walks): | ||
for env_i in np.where(finished_walks)[0]: | ||
trained_env.reset_env(env_i) | ||
trained_agent.prev_iter[0].a[env_i] = None | ||
|
||
max_steps_per_env[env_i] = params["n_rollout"] * np.random.randint( | ||
trained_agent.walk_length_center - params["walk_it_window"] * 0.5, | ||
trained_agent.walk_length_center + params["walk_it_window"] * 0.5, | ||
) | ||
current_steps[env_i] = 0 |
Large diffs are not rendered by default.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this really necessary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. I had some problems early on, cloning the dev version of NeuralPlayground on both Windows and Mac OS. If this has been fixed, then this is redundant and I'll change it.