Skip to content

Commit

Permalink
dont run cv2.imshow in headless environments
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 23, 2024
1 parent af83333 commit 8f82cc9
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
7 changes: 4 additions & 3 deletions neuralplayground/arenas/discritized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def plot_trajectory(
else:
return ax

def render(self, history_length=30):
def render(self, history_length=30, display=True):
"""Render the environment live through iterations"""
f, ax = plt.subplots(1, 1, figsize=(8, 6))
canvas = FigureCanvas(f)
Expand All @@ -416,5 +416,6 @@ def render(self, history_length=30):
image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
image = image.reshape(f.canvas.get_width_height()[::-1] + (3,))
print(image.shape)
cv2.imshow("2D_env", image)
cv2.waitKey(10)
if display:
cv2.imshow("2D_env", image)
cv2.waitKey(10)
7 changes: 4 additions & 3 deletions neuralplayground/arenas/simple2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def plot_trajectory(
else:
return ax

def render(self, history_length=30):
def render(self, history_length=30, display=True):
"""Render the environment live through iterations as in OpenAI gym"""
f, ax = plt.subplots(1, 1, figsize=(8, 6))
canvas = FigureCanvas(f)
Expand All @@ -357,5 +357,6 @@ def render(self, history_length=30):
image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
image = image.reshape(f.canvas.get_width_height()[::-1] + (3,))
print(image.shape)
cv2.imshow("2D_env", image)
cv2.waitKey(10)
if display:
cv2.imshow("2D_env", image)
cv2.waitKey(10)
2 changes: 1 addition & 1 deletion tests/arena_exp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_agent_interaction(self, init_env):
action = agent.act(obs)
# Run environment for given action
obs, state, reward = init_env[0].step(action)
init_env[0].render()
init_env[0].render(display=False)
init_env[0].plot_trajectory()


Expand Down

0 comments on commit 8f82cc9

Please sign in to comment.