Commit
•
5b65f30
1
Parent(s):
c7c38dd
Update modeling_jat.py
Browse files- modeling_jat.py +7 -0
modeling_jat.py
CHANGED
@@ -711,6 +711,7 @@ class JatModel(GPTNeoPreTrainedModel):
|
|
711 |
action_space: Union[spaces.Box, spaces.Discrete] = None,
|
712 |
reward: Optional[float] = None,
|
713 |
deterministic: bool = False,
|
|
|
714 |
):
|
715 |
# Get the maximum sequence length
|
716 |
max_length = self.config.max_position_embeddings // 2
|
@@ -804,6 +805,12 @@ class JatModel(GPTNeoPreTrainedModel):
|
|
804 |
# We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
|
805 |
self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)
|
806 |
|
|
|
|
|
|
|
|
|
|
|
|
|
807 |
# Return the predicted action
|
808 |
if continuous_actions is not None:
|
809 |
self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
|
|
|
711 |
action_space: Union[spaces.Box, spaces.Discrete] = None,
|
712 |
reward: Optional[float] = None,
|
713 |
deterministic: bool = False,
|
714 |
+
context_window: Optional[int] = None,
|
715 |
):
|
716 |
# Get the maximum sequence length
|
717 |
max_length = self.config.max_position_embeddings // 2
|
|
|
805 |
# We remove the last two values, as the inputs are [s_0, 0], [s_0, a_0, s_1, 0], [s_1, a_1, s_2, 0], ...
|
806 |
self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)
|
807 |
|
808 |
+
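# Context window: keep only the last `context_window` entries along the third axis of each cached key/value tensor (the same axis trimmed above); note that 0 keeps the whole cache, since `-0:` is `0:`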
|
809 |
+
if context_window is not None:
|
810 |
+
self._last_key_values = tuple(
|
811 |
+
tuple(pkv[:, :, -context_window:] for pkv in pkvs) for pkvs in self._last_key_values
|
812 |
+
)
|
813 |
+
|
814 |
# Return the predicted action
|
815 |
if continuous_actions is not None:
|
816 |
self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
|