# rl-fin / trading_env.py
# bonadio's picture
# Using my own version of anytrading
# 174772e
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
from enum import Enum
import matplotlib.pyplot as plt
class Actions(Enum):
    """Discrete actions accepted by the trading environment.

    The environment compares incoming actions against the integer
    ``.value`` of each member, and the size of this enum defines the
    size of the discrete action space.
    """

    Sell = 0        # open a short when flat, or close an open long
    Buy = 1         # open a long when flat, or close an open short
    Do_nothing = 2  # keep the current position untouched
class TradingEnv(gym.Env):
    """Single-asset trading environment with long, short and flat positions.

    The agent observes a sliding window of ``(price, price_diff)`` rows and
    picks one of the ``Actions`` at every tick.  The current position is
    stored as a signed integer in ``self._position``: ``1`` for long,
    ``-1`` for short and ``0`` for flat.

    ``self._position_history`` holds one event code per step (used by
    ``render``):

    * ``1``  - open buy (go long)
    * ``2``  - close buy (exit long)
    * ``3``  - open sell (go short)
    * ``4``  - close sell (exit short)
    * ``0``  - do nothing
    * ``-1`` - ignored action (buy while long, or sell while short)
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, df, window_size, frame_bound):
        """Build the environment from a price table.

        Args:
            df: two-dimensional table; its ``'Close'`` column is read via
                ``df.loc[:, 'Close']`` to obtain the price series.
            window_size: number of rows in every observation.
            frame_bound: ``(start, end)`` pair selecting the slice of ``df``
                to trade on; ``start - window_size`` must be a valid,
                non-negative row index.
        """
        assert df.ndim == 2
        assert len(frame_bound) == 2

        self.frame_bound = frame_bound
        self.seed()
        self.df = df
        self.window_size = window_size
        self.prices, self.signal_features = self._process_data()
        self.shape = (window_size, self.signal_features.shape[1])

        # spaces
        self.action_space = spaces.Discrete(len(Actions))
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float64
        )

        # episode
        self._start_tick = self.window_size
        self._end_tick = len(self.prices) - 1
        self._done = None
        self._current_tick = None
        self._last_trade_tick = None
        self._position = None
        self._position_history = None
        self._total_reward = None
        self._total_profit = None
        self._first_rendering = None
        self.history = None

        # fees
        self.trade_fee_bid_percent = 0.0005  # unit
        self.trade_fee_ask_percent = 0.0005  # unit

    def seed(self, seed=None):
        """Seed the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Start a new episode and return the first observation."""
        self._done = False
        self._current_tick = self._start_tick
        self._last_trade_tick = self._current_tick - 1
        self._position = 0
        # Pad with None so that history index i lines up with price tick i.
        self._position_history = (self.window_size * [None])
        self._total_reward = 0.
        self._total_profit = 0.
        self.history = {}
        return self._get_observation()

    def _calculate_reward(self, action):
        """Apply ``action`` to the current position and return the step reward.

        Besides computing the reward this method mutates the episode state:
        it updates ``_position`` and ``_last_trade_tick`` when a position is
        opened, adds the realised price difference to ``_total_profit`` when
        a position is closed, and appends an event code to
        ``_position_history``.

        Args:
            action: integer action, compared against ``Actions.<X>.value``.

        Returns:
            The reward for this step.  An action matching none of the
            ``Actions`` values yields ``0`` and records no event.
        """
        step_reward = 0

        current_price = self.prices[self._current_tick]
        last_price = self.prices[self._current_tick - 1]
        price_diff = current_price - last_price

        # OPEN BUY - 1
        if action == Actions.Buy.value and self._position == 0:
            self._position = 1
            step_reward += price_diff
            self._last_trade_tick = self._current_tick - 1
            self._position_history.append(1)
        # Buy while already long: ignored
        elif action == Actions.Buy.value and self._position > 0:
            step_reward += 0
            self._position_history.append(-1)
        # CLOSE SELL - 4
        elif action == Actions.Buy.value and self._position < 0:
            self._position = 0
            step_reward += -1 * (self.prices[self._current_tick - 1] - self.prices[self._last_trade_tick])
            self._total_profit += step_reward
            self._position_history.append(4)
        # OPEN SELL - 3
        elif action == Actions.Sell.value and self._position == 0:
            self._position = -1
            step_reward += -1 * price_diff
            self._last_trade_tick = self._current_tick - 1
            self._position_history.append(3)
        # CLOSE BUY - 2
        elif action == Actions.Sell.value and self._position > 0:
            self._position = 0
            step_reward += self.prices[self._current_tick - 1] - self.prices[self._last_trade_tick]
            self._total_profit += step_reward
            self._position_history.append(2)
        # Sell while already short: ignored
        elif action == Actions.Sell.value and self._position < 0:
            step_reward += 0
            self._position_history.append(-1)
        # DO NOTHING - 0
        elif action == Actions.Do_nothing.value and self._position > 0:
            step_reward += price_diff
            self._position_history.append(0)
        elif action == Actions.Do_nothing.value and self._position < 0:
            step_reward += -1 * price_diff
            self._position_history.append(0)
        elif action == Actions.Do_nothing.value and self._position == 0:
            # Staying flat is penalised by the size of the missed move.
            step_reward += -1 * abs(price_diff)
            self._position_history.append(0)

        return step_reward

    def step(self, action):
        """Advance one tick.

        Args:
            action: integer action, see ``Actions``.

        Returns:
            ``(observation, step_reward, done, info)`` where ``info`` carries
            ``total_reward``, ``total_profit`` and ``position``.
        """
        self._done = False
        self._current_tick += 1

        if self._current_tick == self._end_tick:
            self._done = True

        step_reward = self._calculate_reward(action)
        self._total_reward += step_reward

        observation = self._get_observation()
        info = dict(
            total_reward=self._total_reward,
            total_profit=self._total_profit,
            position=self._position,
        )
        self._update_history(info)

        return observation, step_reward, self._done, info

    def _get_observation(self):
        """Return the ``window_size`` feature rows ending at the current tick."""
        first_row = self._current_tick - self.window_size + 1
        return self.signal_features[first_row:self._current_tick + 1]

    def _update_history(self, info):
        """Append every value of ``info`` to the per-key lists in ``history``."""
        if not self.history:
            self.history = {key: [] for key in info}

        for key, value in info.items():
            self.history[key].append(value)

    def render(self, mode='human'):
        """Plot the price series and the recorded position events.

        Draws on the current pyplot figure; ``pause_rendering`` or
        ``save_rendering`` is needed to actually show or persist it.
        Ignored actions (code ``-1``) and the ``None`` padding are not drawn.
        """
        plt.plot(self.prices)

        def ticks_with(code):
            # Indices of the steps whose recorded event equals `code`.
            return [tick for tick, event in enumerate(self._position_history)
                    if event == code]

        open_buy = ticks_with(1)
        close_buy = ticks_with(2)
        open_sell = ticks_with(3)
        close_sell = ticks_with(4)
        do_nothing = ticks_with(0)

        # One format string per series: colour + marker, no connecting line.
        plt.plot(open_buy, self.prices[open_buy], 'g^')
        plt.plot(close_buy, self.prices[close_buy], 'gv')
        plt.plot(open_sell, self.prices[open_sell], 'rv')
        plt.plot(close_sell, self.prices[close_sell], 'r^')
        plt.plot(do_nothing, self.prices[do_nothing], 'yo')

        plt.suptitle(
            "Total Reward: %.6f" % self._total_reward + ' ~ ' +
            "Total Profit: %.6f" % self._total_profit
        )

    def close(self):
        """Close the current pyplot figure."""
        plt.close()

    def save_rendering(self, filepath):
        """Save the current pyplot figure to ``filepath``."""
        plt.savefig(filepath)

    def pause_rendering(self):
        """Show the current pyplot figure."""
        plt.show()

    def _process_data(self):
        """Extract the price window and the observation features.

        Returns:
            ``(prices, signal_features)`` where ``prices`` is the ``'Close'``
            column sliced to ``[frame_bound[0] - window_size, frame_bound[1])``
            and ``signal_features`` stacks the prices with their first
            difference (``0`` for the first row) as two columns.

        Raises:
            IndexError: if ``frame_bound[0] - window_size`` is not a valid,
                non-negative index into the price column.
        """
        prices = self.df.loc[:, 'Close'].to_numpy()

        start = self.frame_bound[0] - self.window_size
        end = self.frame_bound[1]
        # A negative start would silently wrap around to the end of the
        # series, so reject it explicitly along with out-of-range values.
        if not 0 <= start < len(prices):
            raise IndexError(
                "frame_bound[0] - window_size = %s is out of bounds for %d prices"
                % (start, len(prices))
            )
        prices = prices[start:end]

        diff = np.insert(np.diff(prices), 0, 0)
        signal_features = np.column_stack((prices, diff))

        return prices, signal_features

    def _update_profit(self, action):
        """Compound ``_total_profit`` with fees when a long position ends.

        The profit is only updated if ``action`` closes the current position
        or the episode is done, and only for a long position.  Not called by
        ``step`` in this class.
        """
        trade = (
            (action == Actions.Buy.value and self._position < 0) or
            (action == Actions.Sell.value and self._position > 0)
        )

        if trade or self._done:
            current_price = self.prices[self._current_tick]
            last_trade_price = self.prices[self._last_trade_tick]

            if self._position > 0:
                shares = (self._total_profit * (1 - self.trade_fee_ask_percent)) / last_trade_price
                self._total_profit = (shares * (1 - self.trade_fee_bid_percent)) * current_price

    def max_possible_profit(self):
        """Return the fee-free profit multiplier of a perfect long-only trader.

        Walks the price series in maximal falling / non-falling runs and,
        starting from ``1.0``, compounds the profit over every non-falling
        run as if it had been bought at the start and sold at the end.
        """
        current_tick = self._start_tick
        last_trade_tick = current_tick - 1
        profit = 1.

        while current_tick <= self._end_tick:
            if self.prices[current_tick] < self.prices[current_tick - 1]:
                # Falling run: a short would be held here, nothing compounds.
                while (current_tick <= self._end_tick and
                       self.prices[current_tick] < self.prices[current_tick - 1]):
                    current_tick += 1
                position = -1
            else:
                # Non-falling run: hold long until the run ends.
                while (current_tick <= self._end_tick and
                       self.prices[current_tick] >= self.prices[current_tick - 1]):
                    current_tick += 1
                position = 1

            if position > 0:
                current_price = self.prices[current_tick - 1]
                last_trade_price = self.prices[last_trade_tick]
                shares = profit / last_trade_price
                profit = shares * current_price
            last_trade_tick = current_tick - 1

        return profit