File size: 461 Bytes
6f3bdf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import gym
import numpy as np
from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
class VideoCompatWrapper(VecotarableWrapper):
    """Ensure ``rgb_array`` renders come back as ``uint8`` image arrays.

    A frame returned by the wrapped env's ``render(mode="rgb_array")`` that is
    a NumPy array of any dtype other than ``np.uint8`` is converted to
    ``np.uint8``. Other render modes and non-array results pass through
    unchanged.

    NOTE(review): presumably this exists for video recorders that require
    uint8 frames -- confirm against callers.
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)

    def render(self, mode="human", **kwargs):
        """Render the wrapped env, coercing ``rgb_array`` frames to ``uint8``.

        Args:
            mode: Render mode forwarded to the wrapped env's ``render``.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the wrapped ``render`` returns, except in one case. If
            ``mode == "rgb_array"`` and the result is a non-uint8
            ``np.ndarray``, it is converted to ``np.uint8``. Values are
            saturated to ``[0, 255]`` and fractional parts are truncated.
        """
        r = super().render(mode=mode, **kwargs)
        if mode == "rgb_array" and isinstance(r, np.ndarray) and r.dtype != np.uint8:
            # A bare astype(np.uint8) wraps integers modulo 256 (e.g. 256 -> 0).
            # It also makes out-of-range float casts platform-dependent, so
            # saturate first. float64 represents every uint8 value exactly,
            # so in-range pixels convert exactly as they did before.
            r = np.clip(r.astype(np.float64), 0, 255).astype(np.uint8)
        return r
|