[Reinforcement Learning #10] DQN
Purpose of this article
This article provides the code explained in the YouTube video 「【強化学習#10】DQN」 ([Reinforcement Learning #10] DQN).
Table of contents
1 Environment and Agent
2 DQN
3 Visualization functions
4 Simulation
1 Environment and Agent
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(1)
```
```python
from torch import nn
from torch import optim
import torch
import numpy as np

torch.manual_seed(1)
```
```python
class Environment:
    def __init__(self, size=3, lucky=[]):
        self.size = size
        self.lucky = lucky
        self.goal = (size-1, size-1)
        self.states = [(x, y) for x in range(size) for y in range(size)]

    def next_state(self, s, a):
        # the goal is absorbing; moves off the grid leave the state unchanged
        s_next = (s[0] + a[0], s[1] + a[1])
        if s == self.goal:
            return s
        if s_next not in self.states:
            return s
        # stepping onto a "lucky" cell warps to the goal with probability 0.8
        if s_next in self.lucky:
            if np.random.random() < 0.8:
                return self.goal
            else:
                return s_next
        return s_next

    def reward(self, s, s_next):
        # -1 per step, 0 for the step that reaches the goal
        if s == self.goal:
            return -1
        if s_next == self.goal:
            return 0
        return -1
```
```python
class Agent():
    def __init__(self, environment):
        # the four moves as (Δx, Δy): left, down, right, up
        self.actions = [(-1, 0), (0, -1), (1, 0), (0, 1)]
        self.environment = environment

    def action(self, s, a, prob=False):
        # take action a in state s and return the reward and next state
        s_next = self.environment.next_state(s, a)
        r = self.environment.reward(s, s_next)
        return r, s_next
```
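With the environment and agent defined, a single transition can already be stepped through by hand. A minimal check (the variable names below are illustrative and not part of the original article):

```python
# Quick, illustrative check of the Environment/Agent interface (not in the original).
env = Environment(size=3, lucky=[(1, 1)])
agent = Agent(env)

s = (0, 0)
a = (1, 0)                       # move right by one cell
r, s_next = agent.action(s, a)   # reward is -1 for every step that does not reach the goal
print(r, s_next)                 # -> -1 (1, 0)
```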
2 DQN
```python
class NN:
    def __init__(self, agent):
        self.model = self.model()
        self.criterion = nn.MSELoss()
        self.actions = agent.actions

    def model(self):
        # Q-network: input (x, y, Δx, Δy), output a single Q(s, a) value
        model = nn.Sequential()
        model.add_module('fc1', nn.Linear(4, 16))
        model.add_module('relu1', nn.ReLU())
        model.add_module('fc2', nn.Linear(16, 8))
        model.add_module('relu2', nn.ReLU())
        model.add_module('fc3', nn.Linear(8, 1))
        self.optimizer = optim.Adam(model.parameters())
        return model

    def train_model(self, sa, labels, num_train=1000):
        # regress Q(s, a) onto the given targets with repeated Adam steps
        for _ in range(num_train):
            qvalue = self.model(torch.tensor(sa).float())
            loss = self.criterion(qvalue, torch.tensor(labels).float())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def q_max(self, state):
        # evaluate all four actions for a state and return the greedy action and its Q-value
        sa = []
        for action in self.actions:
            sa.append(state + action)
        q = self.model(torch.tensor([np.array(sa)]).float()).detach()
        a_max = np.argmax(q)
        return self.actions[a_max], q[0, a_max, 0]
```
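NN wraps a small fully connected network that maps a 4-dimensional (x, y, Δx, Δy) input to a single Q-value, and q_max evaluates all four actions of a state at once. Even before training it can be exercised mechanically (an illustrative sketch, not from the original; the outputs are meaningless until the network is trained):

```python
# Illustrative check of q_max with an untrained network (not in the original).
agent_tmp = Agent(Environment(size=3))
net_tmp = NN(agent_tmp)

a_best, q_best = net_tmp.q_max((0, 0))   # greedy action and predicted Q-value at state (0, 0)
print(a_best, float(q_best))             # arbitrary values from the randomly initialised network
```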
```python
def get_episode(agent, nn_model, epsilon=0.1):
    # start each episode from a random non-goal state
    s = agent.environment.states[np.random.randint(agent.environment.size**2 - 1)]
    episode = []
    while True:
        if np.random.random() < epsilon:
            # explore: random action restricted to (1, 0) and (0, 1), i.e. toward the goal
            a = agent.actions[np.random.randint(2, 4)]
        else:
            # exploit: greedy action from the current network
            a, _ = nn_model.q_max(s)
        r, s_next = agent.action(s, a)
        episode.append((s, a, r, s_next))
        if s_next == agent.environment.goal:
            break
        s = s_next
    return episode
```
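get_episode rolls out one trajectory under this ε-greedy policy and returns it as a list of (s, a, r, s_next) tuples ending at the goal. A small usage sketch (illustrative names, not from the original):

```python
# Illustrative rollout with a fresh, untrained network (not in the original).
env_tmp = Environment(size=3, lucky=[(1, 1)])
agent_tmp = Agent(env_tmp)
net_tmp = NN(agent_tmp)

episode = get_episode(agent_tmp, net_tmp, epsilon=0.1)
print(len(episode))   # number of steps taken before the goal was reached
print(episode[0])     # first transition: (s, a, r, s_next)
```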
```python
def train(agent, nn_model, epsilon=0.1, num=100, num_train=1000):
    for c in range(num):
        print(f'num : {c+1} ')
        # collect experience with the current ε-greedy policy
        examples = []
        for _ in range(100):
            episode = get_episode(agent, nn_model, epsilon)
            examples += episode
        np.random.shuffle(examples)
        # build (state, action) inputs and bootstrapped targets
        sa = []
        labels = []
        for s, a, r, s_next in examples:
            sa.append(s + a)
            _, q_next = nn_model.q_max(s_next)
            labels.append([r + q_next.detach()])
        nn_model.train_model(sa, labels, num_train)
    show_values(agent, nn_model)
    show_policy(agent, nn_model)
```
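The targets built inside this loop are the one-step Q-learning backup used by DQN: each stored (s, a) pair is regressed toward

$$
y = r + \max_{a'} Q(s_{\mathrm{next}}, a'),
$$

so the discount factor is effectively $\gamma = 1$, and because q_next comes from q_max it is already detached, meaning gradients flow only through $Q(s, a)$ during train_model.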
3 Visualization functions
```python
def show_maze(environment):
    size = environment.size
    fig = plt.figure(figsize=(3, 3))
    plt.plot([-0.5, -0.5], [-0.5, size-0.5], color='k')
    plt.plot([-0.5, size-0.5], [size-0.5, size-0.5], color='k')
    plt.plot([size-0.5, -0.5], [-0.5, -0.5], color='k')
    plt.plot([size-0.5, size-0.5], [size-0.5, -0.5], color='k')
    for i in range(size):
        for j in range(size):
            plt.text(i, j, "{}".format(i+size*j), size=20, ha="center", va="center")
            if (i, j) in environment.lucky:
                x = np.array([i-0.5, i-0.5, i+0.5, i+0.5])
                y = np.array([j-0.5, j+0.5, j+0.5, j-0.5])
                plt.fill(x, y, color="lightgreen")
    plt.axis("off")
```
```python
def show_values(agent, nn_model):
    fig = plt.figure(figsize=(3, 3))
    result = np.zeros([agent.environment.size, agent.environment.size])
    for (x, y) in agent.environment.states:
        a_max, q_max = nn_model.q_max((x, y))
        result[y][x] = q_max
    sns.heatmap(result, square=True, cbar=False, annot=True, fmt='3.2f',
                cmap='autumn_r').invert_yaxis()
    plt.axis("off")
```
```python
def show_policy(agent, nn_model):
    size = agent.environment.size
    fig = plt.figure(figsize=(3, 3))
    plt.plot([-0.5, -0.5], [-0.5, size-0.5], color='k')
    plt.plot([-0.5, size-0.5], [size-0.5, size-0.5], color='k')
    plt.plot([size-0.5, -0.5], [-0.5, -0.5], color='k')
    plt.plot([size-0.5, size-0.5], [size-0.5, -0.5], color='k')

    for i in range(size):
        for j in range(size):
            if (i, j) in agent.environment.lucky:
                x = np.array([i-0.5, i-0.5, i+0.5, i+0.5])
                y = np.array([j-0.5, j+0.5, j+0.5, j-0.5])
                plt.fill(x, y, color="lightgreen")

    rotation = {(-1, 0): 180, (0, 1): 90, (1, 0): 0, (0, -1): 270}
    for s in agent.environment.states:
        if s == agent.environment.goal:
            direction = None
        else:
            a_max, q_max = nn_model.q_max(s)
            direction = rotation[a_max]
        if direction is not None:
            bbox_props = dict(boxstyle='rarrow')
            plt.text(s[0], s[1], ' ', bbox=bbox_props, size=8,
                     ha='center', va='center', rotation=direction)
    plt.axis("off")
```
4 Simulation
```python
env1 = Environment(size=4, lucky=[(1,2), (2,3)])
agent1 = Agent(env1)
show_maze(env1)
```
```python
model1 = NN(agent1)
train(agent1, model1, num=10000)
```
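Besides the value and policy plots drawn at the end of train, the learned greedy choice at any cell can be queried directly (a quick illustrative check, not part of the original):

```python
# Illustrative post-training query (not in the original).
a_best, q_best = model1.q_max((0, 0))   # greedy action and its estimated value at cell (0, 0)
print(a_best, float(q_best))
```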