【強化学習#5】価値反復法

記事の目的

YouTube動画「【強化学習#5】価値反復法」で解説した内容のコードです。

 

目次

  1. 環境とエージェント
  2. 価値反復法
  3. 可視化用関数
  4. シミュレーション

 

1 環境とエージェント

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Pin NumPy's global random generator so any random draws are reproducible
# (no code visible in this file actually draws random numbers).
np.random.seed(seed=1)
class Environment:
    """Square grid-world maze used by the value-iteration demo.

    States are ``(x, y)`` cells of a ``size`` x ``size`` grid and the goal is
    the corner cell ``(size-1, size-1)``. Entering a "lucky" cell sends the
    agent to the goal with probability 0.8 and leaves it on the lucky cell
    with probability 0.2.

    Args:
        size: Side length of the grid.
        lucky: Cells with the lucky-jump behaviour. Defaults to none.
    """

    def __init__(self, size=3, lucky=None):
        self.size = size
        # Use a fresh list per instance instead of a shared mutable default.
        self.lucky = [] if lucky is None else lucky
        self.goal = (size - 1, size - 1)
        # Order (x outer, y inner) fixes the sweep order of in-place value updates.
        self.states = [(x, y) for x in range(size) for y in range(size)]
        # State-value table V(s), initialised to zero for every state.
        self.value = {s: 0 for s in self.states}

    def next_state(self, s, a):
        """Return the outcomes of taking move ``a`` in state ``s``.

        Args:
            s: Current ``(x, y)`` cell.
            a: Unit move ``(dx, dy)``.

        Returns:
            A list of ``(probability, next_state)`` pairs.
        """
        # The goal is absorbing: every action keeps the agent there.
        if s == self.goal:
            return [(1, s)]

        s_next = (s[0] + a[0], s[1] + a[1])

        # Moving off the grid leaves the agent where it was.
        if not (0 <= s_next[0] < self.size and 0 <= s_next[1] < self.size):
            return [(1, s)]

        if s_next in self.lucky:
            return [(0.8, self.goal), (0.2, s_next)]

        return [(1, s_next)]

    def reward(self, s, s_next):
        """Return 1 for a transition that enters the goal, otherwise 0.

        Transitions that start at the goal earn 0, so the goal reward is
        collected only once.
        """
        if s == self.goal:
            return 0
        return 1 if s_next == self.goal else 0
class Agent:
    """Agent holding a policy pi(a|s) over an environment's states.

    Args:
        environment: Object exposing ``states`` (iterable of cells).
        policy: Probabilities for the four moves, in the order of
            ``self.actions`` (left, down, right, up). The default moves right
            or up with probability 1/2 each.
    """

    def __init__(self, environment, policy=(0, 0, 1/2, 1/2)):
        # Unit moves: (-1,0)=left, (0,-1)=down, (1,0)=right, (0,1)=up.
        self.actions = [(-1, 0), (0, -1), (1, 0), (0, 1)]
        self.environment = environment
        # pi(a|s) stored flat, keyed by (state, action).
        self.policy = {
            (s, a): policy[i]
            for s in environment.states
            for i, a in enumerate(self.actions)
        }

 

2 価値反復法

def policy(agent, s, gamma=0.5):
    """Greedy Bellman-optimality backup for a single state.

    For each action computes
    Q(s, a) = sum_{s'} p(s'|s, a) * (r(s, s') + gamma * V(s')),
    then makes ``agent.policy`` greedy at ``s``: probability 1 on the best
    action and 0 on the others. Ties go to the first action in
    ``agent.actions``. If no action has a strictly positive Q-value (e.g. at
    the absorbing goal), every action at ``s`` gets probability 0.

    Args:
        agent: Agent exposing ``actions``, ``environment`` and ``policy``.
        s: State to back up.
        gamma: Discount factor.

    Returns:
        max_a Q(s, a), clamped below at 0.
    """
    env = agent.environment
    q_max = 0
    a_max = None
    for a in agent.actions:
        q = sum(
            p * (env.reward(s, s_next) + gamma * env.value[s_next])
            for p, s_next in env.next_state(s, a)
        )
        if q > q_max:
            q_max, a_max = q, a

    for a in agent.actions:
        agent.policy[(s, a)] = 0
    # Skip when no action beat 0, so no bogus (s, None) key enters the table.
    if a_max is not None:
        agent.policy[(s, a_max)] = 1

    return q_max
def train(agent, gamma=0.5):
    """Run value iteration until a sweep leaves the greedy policy unchanged.

    Every epoch sweeps all states once, overwriting V(s) in place with the
    greedy backup returned by ``policy``. The loop stops as soon as
    ``agent.policy`` is identical before and after a sweep; V itself is not
    checked for convergence. The final values and policy are then plotted.
    """
    env = agent.environment
    epoch = 0
    converged = False
    while not converged:
        snapshot = dict(agent.policy)
        for state in env.states:
            env.value[state] = policy(agent, state, gamma)
        epoch += 1
        print(f'epoch : {epoch} ')
        converged = snapshot == agent.policy
    show_values(agent)
    show_policy(agent)

 

3 可視化用関数

def show_maze(environment):
    """Draw the maze: outer border, numbered cells, lucky cells in light green.

    Cell ``(x, y)`` is labelled ``x + size*y``.
    """
    n = environment.size
    plt.figure(figsize=(3,3))

    edge = n - 0.5
    # Outer border: left, top, bottom, right.
    for xs, ys in (([-0.5, -0.5], [-0.5, edge]),
                   ([-0.5, edge], [edge, edge]),
                   ([edge, -0.5], [-0.5, -0.5]),
                   ([edge, edge], [edge, -0.5])):
        plt.plot(xs, ys, color='k')

    for i in range(n):
        for j in range(n):
            plt.text(i, j, "{}".format(i+n*j), size=20, ha="center", va="center")
            if (i, j) in environment.lucky:
                plt.fill(np.array([i-0.5, i-0.5, i+0.5, i+0.5]),
                         np.array([j-0.5, j+0.5, j+0.5, j-0.5]),
                         color="lightgreen")

    plt.axis("off")
def show_values(agent):
    """Plot the state-value table V(s) as an annotated heatmap.

    Rows are indexed by y and columns by x; the y-axis is inverted so that
    y grows upward.
    """
    env = agent.environment
    plt.figure(figsize=(3,3))
    grid = np.zeros([env.size, env.size])
    for x, y in env.states:
        grid[y, x] = env.value[(x, y)]

    ax = sns.heatmap(grid, square=True, cbar=False, annot=True, fmt='3.2f', cmap='autumn_r')
    ax.invert_yaxis()
    plt.axis("off")
 def show_policy(agent):
    size = agent.environment.size
    fig = plt.figure(figsize=(3,3))

    plt.plot([-0.5, -0.5], [-0.5, size-0.5], color='k')
    plt.plot([-0.5, size-0.5], [size-0.5, size-0.5], color='k')
    plt.plot([size-0.5, -0.5], [-0.5, -0.5], color='k')
    plt.plot([size-0.5, size-0.5], [size-0.5, -0.5], color='k')


    for i in range(size):
        for j in range(size):
            if (i,j) in agent.environment.lucky:
                x = np.array([i-0.5,i-0.5,i+0.5,i+0.5])
                y = np.array([j-0.5,j+0.5,j+0.5,j-0.5])
                plt.fill(x,y, color="lightgreen")

    rotation = {(-1, 0): 180, (0, 1): 90, (1, 0): 0, (0, -1): 270}
    for y in range(size):
            for x in range(size):
                for a in agent.actions:
                    if (x, y) == agent.environment.goal:
                        direction = None
                    else:
                        for a in agent.actions:
                            if agent.policy[((x, y), a)] == 1:
                                direction = rotation[a]
                    if direction != None:
                        bbox_props = dict(boxstyle='rarrow')
                        plt.text(x, y, '     ', bbox=bbox_props, size=8,
                                     ha='center', va='center', rotation=direction)
                        
    plt.axis("off")

 

4 シミュレーション

4.1 シミュレーション1

# Simulation 1: 3x3 maze with a single lucky cell at (1, 2).
env1 = Environment(size=3, lucky=[(1, 2)])
agent1 = Agent(env1)
show_maze(env1)

train(agent1, gamma=0.5)

 

4.2 シミュレーション2

# Simulation 2: 4x4 maze with lucky cells at (1, 2) and (2, 3),
# starting from the default "right or up" policy.
env2 = Environment(size=4, lucky=[(1, 2), (2, 3)])
agent2 = Agent(env2, policy=[0, 0, 1/2, 1/2])
show_maze(env2)

train(agent2, gamma=0.5)

 

4.3 シミュレーション3

# Simulation 3: same maze as simulation 2, but starting from a uniform policy.
# Each sweep overwrites every policy entry, so the starting policy only
# influences the first convergence check in train().
env3 = Environment(size=4, lucky=[(1, 2), (2, 3)])
agent3 = Agent(env3, policy=[1/4] * 4)
show_maze(env3)

train(agent3, gamma=0.5)