Skip to content

Commit

Permalink
name experiment, the edges representation and count params
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Nov 6, 2023
1 parent 0a8f2d9 commit 6c113ac
Show file tree
Hide file tree
Showing 4 changed files with 791 additions and 48 deletions.
85 changes: 46 additions & 39 deletions neuralplayground/agents/domine_2023.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,54 +106,17 @@ def __init__( # autogenerated
wandb.init(
project="graph-test",
entity="graph-brain",
name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S"),
name=experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"),
)
self.wandb_logs = {}
save_path = wandb.run.dir
os.mkdir(os.path.join(save_path, "results"))
self.save_path = os.path.join(save_path, "results")

wandb_logs = {
"train_on_shortest_path": train_on_shortest_path,
"resample": resample,

"batch_size_test": batch_size_test,
"nx_min_test": nx_min_test, # This is thought of the state density
"nx_max_test": nx_max_test, # This is thought of the state density
"batch_size": batch_size,
"nx_min": nx_min, # This is thought of the state density
"nx_max": nx_max,

"seed": seed,
"feature_position": feature_position,
"weighted": weighted,

"num_hidden": num_hidden,
"num_layers": num_layers,
"num_message_passing_steps": num_message_passing_steps,
"learning_rate": learning_rate,
"num_training_steps": num_training_steps,
}

if self.wandb_on:
wandb.log(wandb_logs)

else:
dateTimeObj = datetime.now()
save_path = os.path.join(Path(os.getcwd()).resolve(), "results")
os.mkdir(
os.path.join(
save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S")
)
)
self.save_path = os.path.join(
os.path.join(
save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S")
)
)

self.reset()
self.saving_run_parameters()


rng = jax.random.PRNGKey(self.seed)
self.rng_seq = rng_sequence_from_rng(rng)
Expand Down Expand Up @@ -181,11 +144,14 @@ def __init__( # autogenerated
)
net_hk = hk.without_apply_rng(hk.transform(forward))
params = net_hk.init(rng, self.graph)
param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
print("Total number of parameters: %d" % param_count)
self.params = params
optimizer = optax.adam(self.learning_rate)
opt_state = optimizer.init(self.params)
self.opt_state = opt_state


def compute_loss(params, inputs, targets):
outputs = net_hk.apply(params, inputs)
return jnp.mean((outputs[0].nodes - targets) ** 2)
Expand Down Expand Up @@ -231,6 +197,47 @@ def evaluate(params, inputs, target,wse_value=True,indices=None):

self._evaluate = evaluate

wandb_logs = {
"train_on_shortest_path": train_on_shortest_path,
"resample": resample,

"batch_size_test": batch_size_test,
"nx_min_test": nx_min_test, # This is thought of the state density
"nx_max_test": nx_max_test, # This is thought of the state density
"batch_size": batch_size,
"nx_min": nx_min, # This is thought of the state density
"nx_max": nx_max,

"seed": seed,
"feature_position": feature_position,
"weighted": weighted,

"num_hidden": num_hidden,
"num_layers": num_layers,
"num_message_passing_steps": num_message_passing_steps,
"learning_rate": learning_rate,
"num_training_steps": num_training_steps,
"param_count": param_count,
}

if self.wandb_on:
wandb.log(wandb_logs)

else:
dateTimeObj = datetime.now()
save_path = os.path.join(Path(os.getcwd()).resolve(), "results")
os.mkdir(
os.path.join(
save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S")
)
)
self.save_path = os.path.join(
os.path.join(
save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S")
)
)
self.saving_run_parameters()

def saving_run_parameters(self):
path = os.path.join(self.save_path, "run.py")
HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023.py")
Expand Down
Loading

0 comments on commit 6c113ac

Please sign in to comment.