diff --git a/neuralplayground/arenas/discritized_objects.py b/neuralplayground/arenas/discritized_objects.py index 0720778..dc696fb 100644 --- a/neuralplayground/arenas/discritized_objects.py +++ b/neuralplayground/arenas/discritized_objects.py @@ -97,8 +97,8 @@ def __init__( self.arena_limits = np.array( [[self.arena_x_limits[0], self.arena_x_limits[1]], [self.arena_y_limits[0], self.arena_y_limits[1]]] ) - self.room_width = np.diff(self.arena_x_limits)[0] - self.room_depth = np.diff(self.arena_y_limits)[0] + self.room_width = np.diff(self.arena_x_limits)[0].item() + self.room_depth = np.diff(self.arena_y_limits)[0].item() self.agent_step_size = env_kwargs["agent_step_size"] self._create_default_walls() self._create_custom_walls() @@ -413,9 +413,15 @@ def render(self, history_length=30, display=True): history = self.history[-history_length:] ax = self.plot_trajectory(history_data=history, ax=ax) canvas.draw() - image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") - image = image.reshape(f.canvas.get_width_height()[::-1] + (3,)) - print(image.shape) + width, height = f.canvas.get_width_height() + # Get the RGBA buffer from the canvas + image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8") + image = image.reshape((height, width, 4)) + # Remove the alpha channel (RGBA -> RGB) + image_rgb = image[:, :, :3] + # Convert RGB to BGR for OpenCV + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + print(image_bgr.shape) if display: - cv2.imshow("2D_env", image) + cv2.imshow("2D_env", image_bgr) cv2.waitKey(10) diff --git a/neuralplayground/arenas/simple2d.py b/neuralplayground/arenas/simple2d.py index d98e421..e023203 100644 --- a/neuralplayground/arenas/simple2d.py +++ b/neuralplayground/arenas/simple2d.py @@ -107,8 +107,8 @@ def __init__( [self.arena_y_limits[0], self.arena_y_limits[1]], ] ) - self.room_width = np.diff(self.arena_x_limits)[0] - self.room_depth = np.diff(self.arena_y_limits)[0] + self.room_width = np.diff(self.arena_x_limits)[0].item() + self.room_depth = np.diff(self.arena_y_limits)[0].item() self.observation_space = Box( low=np.array([self.arena_x_limits[0], self.arena_y_limits[0]]), high=np.array([self.arena_x_limits[1], self.arena_y_limits[1]]), @@ -354,9 +354,15 @@ def render(self, history_length=30, display=True): history = self.history[-history_length:] ax = self.plot_trajectory(history_data=history, ax=ax) canvas.draw() - image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") - image = image.reshape(f.canvas.get_width_height()[::-1] + (3,)) - print(image.shape) + width, height = f.canvas.get_width_height() + # Get the RGBA buffer from the canvas + image = np.frombuffer(canvas.buffer_rgba(), dtype="uint8") + image = image.reshape((height, width, 4)) + # Remove the alpha channel (RGBA -> RGB) + image_rgb = image[:, :, :3] + # Convert RGB to BGR for OpenCV + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + print(image_bgr.shape) if display: - cv2.imshow("2D_env", image) + cv2.imshow("2D_env", image_bgr) cv2.waitKey(10)