-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathplot_utils.py
68 lines (58 loc) · 2.29 KB
/
plot_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_blackjack_values(V):
    """Render a state-value table as two stacked 3-D surfaces and show them.

    One surface is drawn for states with a usable ace and one for states
    without.  The x axis is the player's current sum (11..21), the y axis is
    the dealer's showing card (1..10) and the height is the state value.

    Args:
        V: container indexed by ``(player_sum, dealer_card, usable_ace)``
            tuples that supports ``in`` and ``[]``.  States that are absent
            are plotted with a value of 0.
    """
    player_sums = np.arange(11, 22)
    dealer_cards = np.arange(1, 11)
    grid_x, grid_y = np.meshgrid(player_sums, dealer_cards)

    def value_at(player_sum, dealer_card, usable_ace):
        # Unvisited states fall back to a neutral value of 0.
        state = (player_sum, dealer_card, usable_ace)
        return V[state] if state in V else 0

    def draw_surface(usable_ace, axes):
        # Look up one value per grid point, then fold the flat list back
        # into the grid's 2-D shape so it lines up with grid_x / grid_y.
        heights = np.array([
            value_at(player_sum, dealer_card, usable_ace)
            for player_sum, dealer_card in zip(np.ravel(grid_x), np.ravel(grid_y))
        ]).reshape(grid_x.shape)
        # The colour scale is pinned to [-1, 1] so both panels are comparable.
        axes.plot_surface(
            grid_x, grid_y, heights,
            rstride=1, cstride=1,
            cmap=plt.cm.coolwarm, vmin=-1.0, vmax=1.0,
        )
        axes.set_xlabel('Player\'s Current Sum')
        axes.set_ylabel('Dealer\'s Showing Card')
        axes.set_zlabel('State Value')
        # Keep the current elevation; only rotate the camera azimuth.
        axes.view_init(axes.elev, -120)

    fig = plt.figure(figsize=(20, 20))
    panels = (
        (211, 'Usable Ace', True),
        (212, 'No Usable Ace', False),
    )
    for position, title, usable_ace in panels:
        ax = fig.add_subplot(position, projection='3d')
        ax.set_title(title)
        draw_surface(usable_ace, ax)
    plt.show()
def plot_policy(policy):
    """Show a policy as two side-by-side colour-coded grids.

    One grid is drawn for states with a usable ace and one for states
    without.  Columns are the player's current sum (11..21), rows are the
    dealer's showing card (1..10), and each cell is coloured by the action
    stored for that state: 0 (STICK) or 1 (HIT).

    Args:
        policy: container indexed by ``(player_sum, dealer_card, usable_ace)``
            tuples that supports ``in`` and ``[]``.  States that are absent
            are drawn as action 1 (HIT).
    """
    def get_Z(x, y, usable_ace):
        # States missing from the policy default to 1 (HIT).
        state = (x, y, usable_ace)
        return policy[state] if state in policy else 1

    def get_figure(usable_ace, ax):
        x_range = np.arange(11, 22)
        # Rows are built from dealer card 10 down to 1.
        y_range = np.arange(10, 0, -1)
        Z = np.array([[get_Z(x, y, usable_ace) for x in x_range] for y in y_range])
        # The half-unit extent makes every cell one unit wide and centres it
        # on an integer, so the ticks below sit in the middle of each cell.
        image = ax.imshow(
            Z,
            cmap=plt.get_cmap('Pastel2', 2),
            vmin=0,
            vmax=1,
            extent=[10.5, 21.5, 0.5, 10.5],
        )
        # Configure the axes that was passed in rather than pyplot's implicit
        # "current axes", so the helper stays correct even when `ax` is not
        # the most recently created axes.
        ax.set_xticks(x_range)
        ax.set_yticks(y_range)
        ax.invert_yaxis()
        ax.set_xlabel('Player\'s Current Sum')
        ax.set_ylabel('Dealer\'s Showing Card')
        ax.grid(color='w', linestyle='-', linewidth=1)
        # Carve a narrow strip off the right of this axes for its colorbar.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.1)
        cbar = ax.figure.colorbar(image, ticks=[0, 1], cax=cax)
        cbar.ax.set_yticklabels(['0 (STICK)', '1 (HIT)'])

    fig = plt.figure(figsize=(15, 15))
    panels = (
        (121, 'Usable Ace', True),
        (122, 'No Usable Ace', False),
    )
    for position, title, usable_ace in panels:
        ax = fig.add_subplot(position)
        ax.set_title(title)
        get_figure(usable_ace, ax)
    plt.show()