pre-merge
LukeHollingsworth committed Nov 20, 2024
1 parent 4afe24c commit d423f89
Showing 4 changed files with 324 additions and 55 deletions.
247 changes: 247 additions & 0 deletions examples/agent_examples/plotting_metrics.py
@@ -0,0 +1,247 @@
import re
from collections import defaultdict

import matplotlib.pyplot as plt


def parse_and_plot_run_log(file_path):
    iterations = []
    losses = []
    accuracies_p = []
    accuracies_g = []
    accuracies_gt = []
    new_walks = []

    iter_pattern = r"Finished backprop iter (\d+)"
    loss_pattern = r"Loss: ([\d.]+)\."  # Note the added \. to catch the trailing period
    accuracy_pattern = r"Accuracy: <p> ([\d.]+)% <g> ([\d.]+)% <gt> ([\d.]+)%"
    new_walk_pattern = r"Iteration (\d+): new walk"

    with open(file_path, "r") as file:
        for line in file:
            iter_match = re.search(iter_pattern, line)
            if iter_match:
                iterations.append(int(iter_match.group(1)))

            loss_match = re.search(loss_pattern, line)
            if loss_match:
                losses.append(float(loss_match.group(1)))

            accuracy_match = re.search(accuracy_pattern, line)
            if accuracy_match:
                accuracies_p.append(float(accuracy_match.group(1)))
                accuracies_g.append(float(accuracy_match.group(2)))
                accuracies_gt.append(float(accuracy_match.group(3)))

            new_walk_match = re.search(new_walk_pattern, line)
            if new_walk_match:
                new_walks.append(int(new_walk_match.group(1)))

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

    ax1.plot(iterations, losses)
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Loss")
    ax1.set_title("Loss over Iterations")

    ax2.plot(iterations, accuracies_p, label="p accuracy")
    ax2.plot(iterations, accuracies_g, label="g accuracy")
    ax2.plot(iterations, accuracies_gt, label="gt accuracy")
    ax2.set_xlabel("Iteration")
    ax2.set_ylabel("Accuracy (%)")
    ax2.set_title("Accuracies over Iterations")
    ax2.legend()

    # Add vertical lines for new walks
    # for walk in new_walks:
    #     ax1.axvline(x=walk, color='r', linestyle='--', alpha=0.5)
    #     ax2.axvline(x=walk, color='r', linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.show()


def analyse_log_file(file_path):
    # Regular expressions to match IDs and Objs lines
    id_pattern = re.compile(r"IDs: \[([^\]]+)\]")
    obj_pattern = re.compile(r"Objs: \[([^\]]+)\]")
    iter_pattern = re.compile(r"Finished backprop iter (\d+)")
    step_pattern = re.compile(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}:")

    # Initialize data structures
    iteration_data = {}
    current_iteration = None
    id_to_obj_previous = {}
    current_step = 0
    ids = []  # Guard against an Objs line appearing before any IDs line

    with open(file_path, "r") as log_file:
        for line in log_file:
            # Check for iteration number
            iter_match = iter_pattern.search(line)
            if iter_match:
                current_iteration = int(iter_match.group(1))
                current_step = 0  # Reset step counter for new iteration
                continue  # Proceed to next line

            # Check for step line (assumed to start with a timestamp)
            if step_pattern.match(line):
                current_step += 1

            # Extract IDs
            id_match = id_pattern.search(line)
            if id_match:
                ids = list(map(int, id_match.group(1).split(",")))
                continue  # IDs are followed by Objs, proceed to next line

            # Extract Objs
            obj_match = obj_pattern.search(line)
            if obj_match:
                objs = list(map(int, obj_match.group(1).split(",")))

                # Ensure current_iteration is set
                if current_iteration is None:
                    continue  # Skip if iteration is not identified yet

                # Store IDs and Objs for this iteration and step
                if current_iteration not in iteration_data:
                    iteration_data[current_iteration] = []
                iteration_data[current_iteration].append((current_step, ids, objs))

    # Now, process the data to find shifts with detailed information
    shifts = defaultdict(list)  # Key: iteration, Value: list of shift details
    id_to_obj_current = {}

    sorted_iterations = sorted(iteration_data.keys())

    for iteration in sorted_iterations:
        steps = iteration_data[iteration]
        # For each step in the iteration
        for step_num, ids, objs in steps:
            # For each ID in the batch
            for batch_idx, (id_, obj) in enumerate(zip(ids, objs)):
                key = (batch_idx, id_)  # Identify by batch index and ID
                if key in id_to_obj_previous:
                    prev_info = id_to_obj_previous[key]
                    prev_obj = prev_info["obj"]
                    if obj != prev_obj:
                        # Environment has changed for this batch member
                        shifts[iteration].append(
                            {
                                "batch_idx": batch_idx,
                                "id": id_,
                                "prev_obj": prev_obj,
                                "new_obj": obj,
                                "prev_iteration": prev_info["iteration"],
                                "prev_step": prev_info["step"],
                                "current_iteration": iteration,
                                "current_step": step_num,
                            }
                        )
                # Update current mapping
                id_to_obj_current[key] = {"obj": obj, "iteration": iteration, "step": step_num}
        # After processing all steps in the iteration, update previous mapping
        id_to_obj_previous = id_to_obj_current.copy()
        id_to_obj_current.clear()

    # Output the iterations where shifts occurred with detailed information
    print("Environment shifts detected with detailed information:")
    with open("shifts_output.txt", "w") as output_file:
        for iteration in sorted(shifts.keys()):
            shift_list = shifts[iteration]
            if shift_list:
                output_file.write(f"\nIteration {iteration}: number of shifts = {len(shift_list)}\n")
                for shift in shift_list:
                    output_file.write(
                        f"  Batch index {shift['batch_idx']}, ID {shift['id']} changed from "
                        f"object {shift['prev_obj']} (Iteration {shift['prev_iteration']}, Step {shift['prev_step']}) "
                        f"to object {shift['new_obj']} (Iteration {shift['current_iteration']}, "
                        f"Step {shift['current_step']})\n"
                    )


def plot_loss_with_switches(log_file_path, output_file_path, large_switch_threshold):
    # Initialize lists to store data
    iterations = []
    losses = []
    large_switch_iterations = []
    switch_counts = {}

    # Regular expressions to match lines in the log
    loss_pattern = re.compile(r"Loss: ([\d\.]+)")
    iteration_pattern = re.compile(r"Finished backprop iter (\d+)")
    # For the output file with switches
    switch_iteration_pattern = re.compile(r"Iteration (\d+): number of shifts = (\d+)")

    # Parse the training log file
    with open(log_file_path, "r") as log_file:
        current_iteration = None
        for line in log_file:
            # Check for iteration number
            iteration_match = iteration_pattern.search(line)
            if iteration_match:
                current_iteration = int(iteration_match.group(1))
                iterations.append(current_iteration)
                continue  # Move to the next line

            # Check for loss value
            loss_match = loss_pattern.search(line)
            if loss_match and current_iteration is not None:
                # The capture includes the log line's trailing period, so strip it
                loss = float(loss_match.group(1)[:-1])
                losses.append(loss)
                continue  # Move to the next line

    # Parse the output file to get switch information
    with open(output_file_path, "r") as output_file:
        for line in output_file:
            # Check for switch iteration
            switch_iter_match = switch_iteration_pattern.match(line)
            if switch_iter_match:
                iteration = int(switch_iter_match.group(1))
                num_shifts = int(switch_iter_match.group(2))
                # Record iterations with shifts exceeding the threshold
                if num_shifts >= large_switch_threshold:
                    large_switch_iterations.append(iteration)
                    switch_counts[iteration] = num_shifts

    # Ensure the lists are aligned
    iterations = iterations[: len(losses)]

    # Plotting the loss over iterations
    plt.figure(figsize=(12, 6))
    plt.plot(iterations, losses, label="Training Loss", color="blue")

    # Add markers for iterations with large switches
    for switch_iter in large_switch_iterations:
        if switch_iter in iterations:
            idx = iterations.index(switch_iter)
            plt.axvline(x=switch_iter, color="red", linestyle="--", alpha=0.5)
            # Optionally, add a text annotation for the number of shifts
            plt.text(
                switch_iter,
                losses[idx],
                f"{switch_counts[switch_iter]} shifts",
                rotation=90,
                va="bottom",
                ha="center",
                color="red",
                fontsize=8,
            )

    plt.title("Training Loss over Iterations with Large Batch Index Switches")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.show()


parse_and_plot_run_log(
    "/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/begging_full/run.log"
)
# analyse_log_file(
#     "/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/test/run.log"
# )
# plot_loss_with_switches(
#     "/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/test/run.log",
#     "/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/test/output.txt",
#     50,
# )
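For reference, the regexes above assume training-log lines of roughly the following shape; the values here are hypothetical, inferred from the patterns rather than from an actual run, and the call sequence is a sketch of how the two analysis functions chain together (paths are placeholders):

# Hypothetical log lines matched by the patterns above:
#   Accuracy: <p> 85.2% <g> 90.1% <gt> 88.8%
#   Loss: 0.523.
#   Finished backprop iter 1200
#   Iteration 1200: new walk

# Sketch of an end-to-end run: analyse_log_file writes "Iteration N: number of
# shifts = M" summaries to shifts_output.txt, which plot_loss_with_switches reads.
analyse_log_file("run.log")
plot_loss_with_switches("run.log", "shifts_output.txt", large_switch_threshold=50)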
56 changes: 56 additions & 0 deletions examples/agent_examples/whittington_2020_debug.py
@@ -0,0 +1,56 @@
import importlib.util
import os

import numpy as np
import pandas as pd

from neuralplayground.plotting import PlotSim

# simulation_id = "examples/agent_examples/TEM_test_with_break"
simulation_id = "TEM_test_witch_break"
save_path = simulation_id + "/"
plotting_loop_params = {"n_walk": 200}

training_dict = pd.read_pickle(os.path.join(os.getcwd(), save_path, "params.dict"))
model_weights = pd.read_pickle(os.path.join(save_path, "agent"))
model_spec = importlib.util.spec_from_file_location("model", save_path + "whittington_2020_model.py")
model = importlib.util.module_from_spec(model_spec)
model_spec.loader.exec_module(model)
params = pd.read_pickle(os.path.join(save_path, "agent_hyper"))
tem = model.Model(params)
tem.load_state_dict(model_weights)

sim = PlotSim(
    simulation_id=simulation_id,
    agent_class=training_dict["agent_class"],
    agent_params=training_dict["agent_params"],
    env_class=training_dict["env_class"],
    env_params=training_dict["env_params"],
    plotting_loop_params=plotting_loop_params,
)

trained_agent, trained_env = sim.plot_sim(save_path, random_state=False, custom_state=[0.0, 0.0])
# trained_env.plot_trajectories();

max_steps_per_env = np.random.randint(4000, 5000, size=params["batch_size"])
current_steps = np.zeros(params["batch_size"], dtype=int)

obs, state = trained_env.reset(random_state=False, custom_state=[0.0, 0.0])
for i in range(200):
    while trained_agent.n_walk < params["n_rollout"]:
        actions = trained_agent.batch_act(obs)
        obs, state, reward = trained_env.step(actions, normalize_step=True)
    trained_agent.update()

    current_steps += params["n_rollout"]
    finished_walks = current_steps >= max_steps_per_env
    if any(finished_walks):
        for env_i in np.where(finished_walks)[0]:
            trained_env.reset_env(env_i)
            trained_agent.prev_iter[0].a[env_i] = None

            max_steps_per_env[env_i] = params["n_rollout"] * np.random.randint(
                trained_agent.walk_length_center - params["walk_it_window"] * 0.5,
                trained_agent.walk_length_center + params["walk_it_window"] * 0.5,
            )
            current_steps[env_i] = 0
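For orientation, the resampling above draws each environment's next walk length from a window around the agent's walk_length_center. A quick arithmetic sketch, using hypothetical values for the three parameters involved:

# Hypothetical values; the real ones come from the agent and the params dict.
n_rollout = 20
walk_length_center = 40
walk_it_window = 20

low = walk_length_center - walk_it_window * 0.5   # 30.0
high = walk_length_center + walk_it_window * 0.5  # 50.0

# The float bounds are truncated (as in the loop above), so the draw is over
# [30, 50); the resampled walk is n_rollout * randint(30, 50): 600 to 980 steps.
new_max_steps = n_rollout * np.random.randint(low, high)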
51 changes: 21 additions & 30 deletions examples/agent_examples/whittington_2020_example.ipynb

Large diffs are not rendered by default.

25 changes: 0 additions & 25 deletions neuralplayground/agents/whittington_2020.py
@@ -256,34 +256,13 @@ def update(self):
        self.final_model_input = model_input

        forward = self.tem(model_input, self.prev_iter)
        # chunk = [[step[0][0], np.argmax(step[1][0]), step[2][0]] for step in model_input]
        for i in range(len(model_input)):
            # self.logger.info(chunk[i])
            ids = [step["id"] for step in model_input[i][0]]
            objs = [int(np.argmax(step)) for step in model_input[i][1]]
            actions = model_input[i][2]
            self.logger.info("IDs: " + str(ids))
            self.logger.info("Objs: " + str(objs))
            self.logger.info("Actions: " + str(actions))
        # if self.prev_iter is None:
        #     with open('OG_log.txt', 'a') as f:
        #         f.write('Walk number: ' + str(self.global_steps) + '\n')
        #         for c in model_input:
        #             f.write('ID: ' + str(c[0]) + '\n')
        #             f.write('Observation: ' + str([np.argmax(a) for a in c[1]]) + '\n')
        #             f.write('Action: ' + str(c[2]) + '\n')
        #         f.write('prev_iter: ' + str(self.prev_iter) + '\n')
        # else:
        #     with open('OG_log.txt', 'a') as f:
        #         f.write('Walk number: ' + str(self.global_steps) + '\n')
        #         for c in model_input:
        #             f.write('ID: ' + str(c[0]) + '\n')
        #             f.write('Observation: ' + str([np.argmax(a) for a in c[1]]) + '\n')
        #             f.write('Action: ' + str(c[2]) + '\n')
        #         f.write('prev_iter.L: ' + str(self.prev_iter[0].L) + '\n')
        #         f.write('prev_iter.a: ' + str(self.prev_iter[0].a) + '\n')
        #         f.write('prev_iter.M: ' + str(self.prev_iter[0].M) + '\n')
        #         f.write('prev_iter.x: ' + str([torch.argmax(x) for x in self.prev_iter[0].x]) + '\n')

        # Accumulate loss from forward pass
        loss = torch.tensor(0.0)
@@ -346,10 +325,6 @@ def update(self):
        self.accuracy_history["gt_accuracy"].append(acc_gt)
        self.iter += 1

        # if self.iter % 10 == 0:
        #     print("Iteration: ", self.iter)
        #     print("Accuracies: ", acc_p, acc_g, acc_gt)

    def initialise(self):
        """
        Generate random distribution of objects and initialise optimiser, logger and relevant variables
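The IDs/Objs/Actions lines logged in update() above are exactly what analyse_log_file in plotting_metrics.py parses. A hypothetical excerpt of the resulting log, assuming the logger is configured with the usual %(asctime)s timestamp prefix that step_pattern expects (values are illustrative, and the Actions payload is elided):

2024-11-20 12:00:00,123: IDs: [3, 7, 12, 5]
2024-11-20 12:00:00,124: Objs: [1, 4, 9, 2]
2024-11-20 12:00:00,125: Actions: [...]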
