Rendering elegant stock trading agents using Matplotlib and Gym

    ')
    file.close()

Now, let’s move on to creating our new render method. It’s going to utilize our new StockTradingGraph class, which we haven’t written yet. We’ll get to that next.

    def render(self, mode='live', title=None, **kwargs):
        # Render the environment to the screen
        if mode == 'file':
            self._render_to_file(kwargs.get('filename', 'render.txt'))
        elif mode == 'live':
            # Create the graph lazily, the first time a live render is requested
            if self.visualization is None:
                self.visualization = StockTradingGraph(self.df, title)

            if self.current_step > LOOKBACK_WINDOW_SIZE:
                self.visualization.render(
                    self.current_step, self.net_worth, self.trades,
                    window_size=LOOKBACK_WINDOW_SIZE)

We are using kwargs here to pass the optional filename to the file renderer, while title is a regular keyword argument that is handed to the StockTradingGraph.

If you are unfamiliar with kwargs, it is basically a dictionary for passing optional keyword arguments to functions.
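As a minimal sketch of how that works, the hypothetical function below (it is not part of the environment) mimics the render signature and only reports where it would send its output. The name describe_render and the returned strings are assumptions made for illustration.

```python
def describe_render(mode='live', title=None, **kwargs):
    """Report what a render call with these arguments would do."""
    if mode == 'file':
        # kwargs is a plain dict, so .get() supplies a default when
        # the caller did not pass filename=...
        filename = kwargs.get('filename', 'render.txt')
        return f'write to {filename}'
    return f'draw live graph titled {title!r}'


print(describe_render(mode='file'))
print(describe_render(mode='file', filename='trades.txt'))
print(describe_render(title='MSFT'))
```

Any keyword the function does not name explicitly ends up in kwargs, which is why filename can stay optional without appearing in the signature.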

We also pass self.trades for the visualization to render, but have not defined it yet, so let’s do that. Back in our _take_action method, whenever we buy or sell shares, we are now going to add the details of that transaction to the self.trades list, which we’ve initialized to [] in our reset method.

    def _take_action(self, action):
        ...
        if action_type < 1:
            ...
            if shares_bought > 0:
                # Record the buy so the graph can annotate it later
                self.trades.append({'step': self.current_step,
                                    'shares': shares_bought,
                                    'total': additional_cost,
                                    'type': "buy"})
        elif action_type < 2:
            ...
            if shares_sold > 0:
                # Record the sell so the graph can annotate it later
                self.trades.append({'step': self.current_step,
                                    'shares': shares_sold,
                                    'total': shares_sold * current_price,
                                    'type': "sell"})

Now our StockTradingGraph has all of the information it needs to render the stock’s price history and trade volume, along with our agent’s net worth and any trades it’s made.
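To see the shape of the data the graph receives, here is a small standalone sketch of the trade log. The record_trade helper and the sample numbers are hypothetical; only the dictionary keys come from the _take_action code.

```python
def record_trade(trades, step, shares, price, trade_type):
    """Append one trade record using the same keys as _take_action."""
    if shares > 0:  # skip no-op trades, as the environment does
        trades.append({'step': step, 'shares': shares,
                       'total': shares * price, 'type': trade_type})


trades = []  # reset() starts every episode with an empty list
record_trade(trades, step=41, shares=10, price=25.0, trade_type='buy')
record_trade(trades, step=42, shares=0, price=26.0, trade_type='sell')
record_trade(trades, step=43, shares=4, price=27.5, trade_type='sell')

for trade in trades:
    print(trade)
```

Only two of the three calls leave a record, because a trade of zero shares is never appended.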

Let’s get started rendering our visualization.

First, we’ll define our StockTradingGraph and its __init__ method.

Here is where we will create our pyplot figure, and set up each of the subplots to be rendered to.

The date2num function is used to reformat dates into timestamps, necessary later in the rendering process.

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates


    def date2num(date):
        # Note: strpdate2num was removed from newer Matplotlib releases;
        # there, mdates.datestr2num(date) can be used instead
        converter = mdates.strpdate2num('%Y-%m-%d')
        return converter(date)


    class StockTradingGraph:
        """A stock trading visualization using matplotlib made to render OpenAI gym environments"""

        def __init__(self, df, title=None):
            self.df = df
            self.net_worths = np.zeros(len(df['Date']))

            # Create a figure on screen and set the title
            fig = plt.figure()
            fig.suptitle(title)

            # Create top subplot for net worth axis
            self.net_worth_ax = plt.subplot2grid(
                (6, 1), (0, 0), rowspan=2, colspan=1)

            # Create bottom subplot for shared price/volume axis
            self.price_ax = plt.subplot2grid(
                (6, 1), (2, 0), rowspan=8, colspan=1, sharex=self.net_worth_ax)

            # Create a new axis for volume which shares its x-axis with price
            self.volume_ax = self.price_ax.twinx()

            # Add padding to make graph easier to view
            plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90,
                                top=0.90, wspace=0.2, hspace=0)

            # Show the graph without blocking the rest of the program
            plt.show(block=False)

We use the plt.subplot2grid(...) method to first create a subplot at the top of our figure to render our net worth grid, and then create another subplot below it for our price grid. The first argument of subplot2grid is the shape of the overall grid (rows, columns) and the second is the location of the subplot within that grid; rowspan and colspan set how many cells it covers. To render our trade volume bars, we call the twinx() method on self.price_ax, which allows us to overlay another set of axes on top that shares the same x-axis.

Finally, and most importantly, we will render our figure to the screen using plt.show(block=False). If you forget to pass block=False, you will only ever see the first step rendered, after which the agent will be blocked from continuing.
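The layout can be tried on its own with the sketch below. It assumes Matplotlib is installed and uses the non-interactive Agg backend so that it runs without a display. It also gives the price subplot a rowspan of 4, so that the two subplots exactly fill the six grid rows. The variable names mirror the class, but the snippet is not the class itself.

```python
import matplotlib
matplotlib.use('Agg')  # off-screen backend; an assumption for headless runs
import matplotlib.pyplot as plt

fig = plt.figure()

# shape=(6, 1) is the grid size; loc=(row, col) is the cell where the subplot starts
net_worth_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2, colspan=1)
price_ax = plt.subplot2grid((6, 1), (2, 0), rowspan=4, colspan=1,
                            sharex=net_worth_ax)

# twinx() returns a new Axes drawn over price_ax with the same x-axis
# but an independent y-axis, which is what the volume bars use
volume_ax = price_ax.twinx()

print(len(fig.axes))
```

The figure ends up holding three Axes objects even though only two grid regions are visible, because the volume axes sit directly on top of the price axes.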

Next, let’s write our render method.

This will take all of the information from the current time step and render a live representation to the screen.

    def render(self, current_step, net_worth, trades, window_size=40):
        self.net_worths[current_step] = net_worth

        window_start = max(current_step - window_size, 0)
        step_range = range(window_start, current_step + 1)

        # Format dates as timestamps, necessary for candlestick graph
        dates = np.array([date2num(x)
                          for x in self.df['Date'].values[step_range]])

        # Pass step_range (not window_size), matching the method's signature
        self._render_net_worth(current_step, net_worth, step_range, dates)
        self._render_price(current_step, net_worth, dates, step_range)
        self._render_volume(current_step, net_worth, dates, step_range)
        self._render_trades(current_step, trades, step_range)

        # Format the date ticks to be more easily read
        self.price_ax.set_xticklabels(self.df['Date'].values[step_range],
                                      rotation=45, horizontalalignment='right')

        # Hide duplicate net worth date labels
        plt.setp(self.net_worth_ax.get_xticklabels(), visible=False)

        # Necessary to view frames before they are unrendered
        plt.pause(0.001)

Here, we save the net_worth, and then render each graph from top to bottom. We’re also going to annotate the price graph with the trades the agent has taken in the self._render_trades method. It’s important to call plt.pause() here, otherwise each frame will be cleared by the next call to render before it has actually been shown on screen.
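The window arithmetic in render is easy to check in isolation. This sketch uses a hypothetical helper, visible_steps, that repeats the two lines computing step_range; the step values are made up for illustration.

```python
def visible_steps(current_step, window_size=40):
    """Return the range of steps shown for the current frame."""
    # Clamp at 0 so early steps never produce a negative start index
    window_start = max(current_step - window_size, 0)
    return range(window_start, current_step + 1)


early = visible_steps(5)
later = visible_steps(100)
print(early.start, early.stop - 1, len(early))
print(later.start, later.stop - 1, len(later))
```

Because the range includes both ends, a full window covers window_size + 1 steps, while the first few frames show fewer.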

Now, let’s look at each of the graph’s render methods, starting with net worth.

    def _render_net_worth(self, current_step, net_worth, step_range, dates):
        # Clear the frame rendered last step
        self.net_worth_ax.clear()

        # Plot net worths
        self.net_worth_ax.plot_date(
            dates, self.net_worths[step_range], '-', label='Net Worth')

        # Show legend, which uses the label we defined for the plot above
        legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={'size': 8})
        legend.get_frame().set_alpha(0.4)

        last_date = date2num(self.df['Date'].values[current_step])
        last_net_worth = self.net_worths[current_step]

        # Annotate the current net worth on the net worth graph
        self.net_worth_ax.annotate('{0:.2f}'.format(net_worth),
                                   (last_date, last_net_worth),
                                   xytext=(last_date, last_net_worth),
                                   bbox=dict(boxstyle='round',
                                             fc='w', ec='k', lw=1),
                                   color="black",
                                   fontsize="small")

        # Add space above and below min/max net worth
        self.net_worth_ax.set_ylim(
            min(self.net_worths[np.nonzero(self.net_worths)]) / 1.25,
            max(self.net_worths) * 1.25)

We just call plot_date(...) on our net worth subplot to plot a simple line graph, then annotate it with the agent’s current net_worth, and add a legend.
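One detail worth spelling out is the y-axis range. The net_worths array is created full of zeros, so steps that have not happened yet would drag the minimum down to 0 unless they are filtered out. The sketch below shows the same idea with plain Python lists instead of NumPy; the helper name net_worth_ylim and the sample values are assumptions.

```python
def net_worth_ylim(net_worths, padding=1.25):
    """Y-axis limits that ignore steps not yet filled in (still 0)."""
    seen = [value for value in net_worths if value != 0]
    return min(seen) / padding, max(seen) * padding


low, high = net_worth_ylim([10000.0, 10500.0, 9800.0, 0.0, 0.0])
print(low, high)
```

Dividing the minimum and multiplying the maximum by the same factor leaves some breathing room on both sides of the line.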

Rendering the price graph is a bit more complicated.

To keep things simple, we are going to render the OHLC bars in a separate method from the volume bars. First, you need to pip install mpl_finance if you don’t already have it, as this package is needed for the candlestick graphs we’ll be using. Then add this line to the top of your file.

    from mpl_finance import candlestick_ochl as candlestick

Great, let’s clear the previous frame, zip up the OHLC data, and render a candlestick graph to the self.price_ax subplot. Note that candlestick_ochl expects each row in open, close, high, low order.

    def _render_price(self, current_step, net_worth, dates, step_range):
        self.price_ax.clear()

        # Format data for the candlestick graph: date, open, close, high, low
        candlesticks = zip(dates,
                           self.df['Open'].values[step_range],
                           self.df['Close'].values[step_range],
                           self.df['High'].values[step_range],
                           self.df['Low'].values[step_range])

        # Plot price using candlestick graph from mpl_finance
        candlestick(self.price_ax, candlesticks, width=1,
                    colorup=UP_COLOR, colordown=DOWN_COLOR)

        last_date = date2num(self.df['Date'].values[current_step])
        last_close = self.df['Close'].values[current_step]
        last_high = self.df['High'].values[current_step]

        # Print the current price to the price axis
        self.price_ax.annotate('{0:.2f}'.format(last_close),
                               (last_date, last_close),
                               xytext=(last_date, last_high),
                               bbox=dict(boxstyle='round',
                                         fc='w', ec='k', lw=1),
                               color="black",
                               fontsize="small")

        # Shift price axis up to give volume chart space
        ylim = self.price_ax.get_ylim()
        self.price_ax.set_ylim(
            ylim[0] - (ylim[1] - ylim[0]) * VOLUME_CHART_HEIGHT, ylim[1])

We’ve annotated the graph with the stock’s current price, and shifted the chart up to prevent it from overlapping with the volume bars.
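The shift itself is plain arithmetic on the axis limits. In this sketch the helper name shift_price_ylim is hypothetical, and 0.33 is only an assumed value for VOLUME_CHART_HEIGHT, since the constant is not defined in this section.

```python
def shift_price_ylim(ylim, volume_chart_height):
    """Lower the bottom limit by a fraction of the current price span."""
    bottom, top = ylim
    return bottom - (top - bottom) * volume_chart_height, top


# 0.33 is an assumed value for VOLUME_CHART_HEIGHT, for illustration only
new_bottom, new_top = shift_price_ylim((100.0, 130.0), 0.33)
print(new_bottom, new_top)
```

The top limit is untouched, so the candles keep their position relative to the top of the axes while empty space opens up beneath them for the volume bars.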

Next let’s look at the volume render method, which is much simpler as there are no annotations.

    def _render_volume(self, current_step, net_worth, dates, step_range):
        self.volume_ax.clear()

        volume = np.array(self.df['Volume'].values[step_range])

        # Boolean masks: pos where the close is above the open, neg where it is below
        pos = self.df['Open'].values[step_range] - \
            self.df['Close'].values[step_range] < 0
        neg = self.df['Open'].values[step_range] - \
            self.df['Close'].values[step_range] > 0

        # Color volume bars based on price direction on that date
        self.volume_ax.bar(dates[pos], volume[pos], color=UP_COLOR,
                           alpha=0.4, width=1, align='center')
        self.volume_ax.bar(dates[neg], volume[neg], color=DOWN_COLOR,
                           alpha=0.4, width=1, align='center')

        # Cap volume axis height below price chart and hide ticks
        self.volume_ax.set_ylim(0, max(volume) / VOLUME_CHART_HEIGHT)
        self.volume_ax.yaxis.set_ticks([])

Just a simple bar graph, with each bar colored either green or red, depending on whether the price moved up or down in that time step.
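The coloring rule can be sketched without NumPy. The version below uses plain lists and a hypothetical helper, split_by_direction, with made-up prices and volumes; the method itself applies the same comparison as boolean masks on arrays.

```python
def split_by_direction(opens, closes, volumes):
    """Group volumes by whether the close was above or below the open."""
    up = [v for o, c, v in zip(opens, closes, volumes) if o - c < 0]
    down = [v for o, c, v in zip(opens, closes, volumes) if o - c > 0]
    return up, down


opens = [10.0, 11.0, 12.0, 12.0]
closes = [11.0, 10.5, 12.0, 13.0]
volumes = [500, 300, 200, 800]

up, down = split_by_direction(opens, closes, volumes)
print(up, down)
```

A step where the open equals the close satisfies neither comparison, so under this rule it gets no volume bar at all.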

Finally, let’s get to the fun part: _render_trades.

In this method, we’ll render an arrow on the price graph where the agent has made a trade, annotated with the total amount transacted.

    def _render_trades(self, current_step, trades, step_range):
        for trade in trades:
            # Only annotate trades that fall inside the visible window
            if trade['step'] in step_range:
                date = date2num(self.df['Date'].values[trade['step']])
                high = self.df['High'].values[trade['step']]
                low = self.df['Low'].values[trade['step']]

                # Anchor buys at the low of the bar and sells at the high
                if trade['type'] == 'buy':
                    high_low = low
                    color = UP_TEXT_COLOR
                else:
                    high_low = high
                    color = DOWN_TEXT_COLOR

                total = '{0:.2f}'.format(trade['total'])

                # Annotate the total amount transacted on the price axis
                self.price_ax.annotate(f'${total}', (date, high_low),
                                       xytext=(date, high_low),
                                       color=color,
                                       fontsize=8,
                                       arrowprops=(dict(color=color)))

And that’s it. We now have a beautiful, live-rendering visualization of the stock trading environment we created in the last article. It’s too bad we still haven’t put much time into teaching the agent how to make money… We’ll leave that for next time!

Thanks for reading. Let me know if you have any questions about this article or anything you’d like me to write about next. If you want to take a look at the code, it’s all available on my Github.

Make sure you’ve followed me on Medium, because next week we’ll be building off the last two articles to create a profitable cryptocurrency trading agent.
