
【強化学習#10】DQN
Purpose of this article
This is the companion code for the content explained in the YouTube video 「【強化学習#10】DQN」 (Reinforcement Learning #10: DQN).
Table of Contents
1 Environment and Agent
2 DQN
3 Visualization functions
4 Simulation

1 Environment and Agent
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(1)

import torch
from torch import nn
from torch import optim

torch.manual_seed(1)
class Environment:

    def __init__(self, size=3, lucky=[]):
        self.size = size
        self.lucky = lucky
        self.goal = (size-1, size-1)
        self.states = [(x, y) for x in range(size) for y in range(size)]

    def next_state(self, s, a):
        s_next = (s[0] + a[0], s[1] + a[1])

        # the goal is absorbing
        if s == self.goal:
            return s

        # moves that leave the grid keep the state unchanged
        if s_next not in self.states:
            return s

        # a lucky square jumps straight to the goal with probability 0.8
        if s_next in self.lucky:
            if np.random.random() < 0.8:
                return self.goal
            else:
                return s_next

        return s_next

    def reward(self, s, s_next):
        if s == self.goal:
            return -1

        # reaching the goal costs nothing, every other step costs 1
        if s_next == self.goal:
            return 0

        return -1
class Agent():

    def __init__(self, environment):
        # left, down, right, up
        self.actions = [(-1, 0), (0, -1), (1, 0), (0, 1)]
        self.environment = environment

    def action(self, s, a, prob=False):
        # take action a in state s and return the reward and the next state
        s_next = self.environment.next_state(s, a)
        r = self.environment.reward(s, s_next)

        return r, s_next
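The environment and agent can be exercised on their own before any learning happens. Below is a minimal sketch (the names env0 and agent0 are mine, not from the article) that takes a single step from the start square onto a lucky square:

# Illustrative only: one interaction step in a small 3x3 grid with one lucky square
env0 = Environment(size=3, lucky=[(0, 1)])
agent0 = Agent(env0)

# moving up from (0, 0) lands on the lucky square (0, 1):
# with probability 0.8 it jumps straight to the goal (2, 2) and the reward is 0,
# otherwise it stays on (0, 1) and the reward is -1
r, s_next = agent0.action((0, 0), (0, 1))
print(r, s_next)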
2 DQN
class NN:

    def __init__(self, agent):
        self.actions = agent.actions
        self.model = self.build_model()
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters())

    def build_model(self):
        # Q-network: input is the 4-vector (x, y, ax, ay), output is the scalar Q(s, a)
        model = nn.Sequential()
        model.add_module('fc1', nn.Linear(4, 16))
        model.add_module('relu1', nn.ReLU())
        model.add_module('fc2', nn.Linear(16, 8))
        model.add_module('relu2', nn.ReLU())
        model.add_module('fc3', nn.Linear(8, 1))
        return model

    def train_model(self, sa, labels, num_train=1000):
        sa = torch.tensor(sa).float()
        labels = torch.tensor(labels).float()
        for _ in range(num_train):
            qvalue = self.model(sa)
            loss = self.criterion(qvalue, labels)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def q_max(self, state):
        # evaluate Q(s, a) for all four actions and return the greedy action and its value
        sa = [state + action for action in self.actions]
        q = self.model(torch.tensor(sa).float()).detach()
        a_max = int(q.argmax())
        return self.actions[a_max], q[a_max, 0]
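The network maps a 4-dimensional input, a state concatenated with an action, to a single Q-value, and q_max simply evaluates all four actions for a given state and keeps the largest. A quick sketch of this on an untrained network (env_tmp, agent_tmp and qnet are illustrative names, not from the article):

# Illustrative only: inspect the untrained Q-network
env_tmp = Environment(size=4, lucky=[(1, 2), (2, 3)])
agent_tmp = Agent(env_tmp)
qnet = NN(agent_tmp)

# the input for state (0, 0) and action (1, 0) is the 4-vector (0, 0, 1, 0)
print(qnet.model(torch.tensor([[0., 0., 1., 0.]])))

# greedy action and Q-value for state (0, 0) under the random initial weights
a_best, q_best = qnet.q_max((0, 0))
print(a_best, float(q_best))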
def get_episode(agent, nn_model, epsilon=0.1):

    # random start state (the goal is the last entry of environment.states, so it is excluded)
    s = agent.environment.states[np.random.randint(agent.environment.size**2-1)]

    episode = []
    while True:
        if np.random.random() < epsilon:
            # random exploration, limited to the two actions (1, 0) and (0, 1) that move toward the goal
            a = agent.actions[np.random.randint(2, 4)]
        else:
            a, _ = nn_model.q_max(s)

        r, s_next = agent.action(s, a)
        episode.append((s, a, r, s_next))

        if s_next == agent.environment.goal:
            break
        s = s_next

    return episode
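Before any training, get_episode can already be used to roll out a trajectory; with a random network the path is just long and wasteful. A sketch, reusing the illustrative agent_tmp and qnet objects from above:

# Illustrative only: one epsilon-greedy episode with the untrained network
episode = get_episode(agent_tmp, qnet, epsilon=0.5)
print(len(episode), 'transitions')
for s, a, r, s_next in episode[:5]:
    print(s, a, r, s_next)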
def train(agent, nn_model, epsilon=0.1, num=100, num_train=1000):
    for c in range(num):
        print(f'num : {c+1}')

        # collect experience with the current network (epsilon-greedy)
        examples = []
        for _ in range(100):
            episode = get_episode(agent, nn_model, epsilon)
            examples += episode
        np.random.shuffle(examples)

        # build the training set: input (s, a), target r + max_a' Q(s', a')
        sa = []
        labels = []
        for s, a, r, s_next in examples:
            sa.append(s + a)
            _, q_next = nn_model.q_max(s_next)
            labels.append([r + float(q_next)])

        nn_model.train_model(sa, labels, num_train)

    show_values(agent, nn_model)
    show_policy(agent, nn_model)
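For every stored transition, train builds one supervised example: the input is the 4-vector s + a and the label is the undiscounted Q-learning target r + max_a' Q(s_next, a'). The sketch below computes a single label by hand, again using the illustrative qnet from above:

# Illustrative only: one (input, label) pair exactly as train() constructs it
s, a, r, s_next = (0, 0), (1, 0), -1, (1, 0)
x_input = s + a                   # (0, 0, 1, 0): the network input
_, q_next = qnet.q_max(s_next)    # max_a' Q(s_next, a') under the current weights
label = [r + float(q_next)]       # Q-learning target with gamma = 1
print(x_input, label)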
3 Visualization functions
def show_maze(environment):
    size = environment.size
    fig = plt.figure(figsize=(3, 3))

    plt.plot([-0.5, -0.5], [-0.5, size-0.5], color='k')
    plt.plot([-0.5, size-0.5], [size-0.5, size-0.5], color='k')
    plt.plot([size-0.5, -0.5], [-0.5, -0.5], color='k')
    plt.plot([size-0.5, size-0.5], [size-0.5, -0.5], color='k')

    for i in range(size):
        for j in range(size):
            plt.text(i, j, "{}".format(i + size*j), size=20, ha="center", va="center")
            if (i, j) in environment.lucky:
                x = np.array([i-0.5, i-0.5, i+0.5, i+0.5])
                y = np.array([j-0.5, j+0.5, j+0.5, j-0.5])
                plt.fill(x, y, color="lightgreen")

    plt.axis("off")
def show_values(agent, nn_model):

    fig = plt.figure(figsize=(3, 3))
    result = np.zeros([agent.environment.size, agent.environment.size])
    for (x, y) in agent.environment.states:
        a_max, q_max = nn_model.q_max((x, y))
        result[y][x] = q_max

    sns.heatmap(result, square=True, cbar=False, annot=True, fmt='3.2f', cmap='autumn_r').invert_yaxis()
    plt.axis("off")
def show_policy(agent, nn_model):
    size = agent.environment.size
    fig = plt.figure(figsize=(3, 3))

    plt.plot([-0.5, -0.5], [-0.5, size-0.5], color='k')
    plt.plot([-0.5, size-0.5], [size-0.5, size-0.5], color='k')
    plt.plot([size-0.5, -0.5], [-0.5, -0.5], color='k')
    plt.plot([size-0.5, size-0.5], [size-0.5, -0.5], color='k')

    for i in range(size):
        for j in range(size):
            if (i, j) in agent.environment.lucky:
                x = np.array([i-0.5, i-0.5, i+0.5, i+0.5])
                y = np.array([j-0.5, j+0.5, j+0.5, j-0.5])
                plt.fill(x, y, color="lightgreen")

    rotation = {(-1, 0): 180, (0, 1): 90, (1, 0): 0, (0, -1): 270}
    for s in agent.environment.states:
        if s == agent.environment.goal:
            direction = None
        else:
            a_max, q_max = nn_model.q_max(s)
            direction = rotation[a_max]

        if direction is not None:
            bbox_props = dict(boxstyle='rarrow')
            plt.text(s[0], s[1], ' ', bbox=bbox_props, size=8,
                     ha='center', va='center', rotation=direction)

    plt.axis("off")
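These helpers only read from the network, so they can also be called before training to see the random starting point (again with the illustrative objects from above):

# Illustrative only: value map and greedy policy of the untrained network
show_values(agent_tmp, qnet)
show_policy(agent_tmp, qnet)
plt.show()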
4 Simulation
env1 = Environment(size=4, lucky=[(1,2), (2,3)])
agent1 = Agent(env1)

show_maze(env1)
model1 = NN(agent1)
train(agent1, model1, num=10000)
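Once training finishes (train already shows the value map and policy at the end), one extra sanity check is to roll out a purely greedy episode with epsilon=0; with a well-trained network it should head almost straight for the goal or a lucky square, though with a poorly trained one this rollout may wander for a long time. This check is my addition, not part of the article:

# Illustrative only: greedy rollout with the trained network
episode = get_episode(agent1, model1, epsilon=0)
print(len(episode), 'steps')
print([s for s, a, r, s_next in episode] + [episode[-1][3]])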