# Source code for csle_agents.agents.pomcp.action_node

from typing import List, Dict, Union
from csle_agents.agents.pomcp.node import Node


class ActionNode(Node):
    """
    A node in the POMCP history tree where the last element of the history was an action.

    Child nodes are additionally indexed by their ``observation`` attribute so that the child reached
    after a given observation can be looked up directly.
    """

    def __init__(self, id: int, history: List[int], action: int, parent=None, value: float = -2000,
                 visit_count: int = 0) -> None:
        """
        Initializes the node

        :param id: the id of the node
        :param history: the history of the node
        :param action: the action of the node
        :param parent: the parent node
        :param value: the value of the node
        :param visit_count: the visit count of the node
        """
        # The observation passed to the base class is fixed to -1 for action nodes.
        super().__init__(id=id, history=history, parent=parent, value=value, visit_count=visit_count,
                         action=action, observation=-1)
        # Running mean of the immediate rewards passed to update_stats
        self.mean_immediate_reward: float = 0.0
        # Lookup table: observation -> child node registered through add_child
        self.observation_to_node_map: Dict[int, Node] = {}

    def update_stats(self, immediate_reward: float) -> None:
        """
        Updates the mean immediate reward of the node by computing the rolling average

        The current visit count is used as the number of samples already included in the mean.
        This method does not modify the visit count itself.
        NOTE(review): presumably the caller increments visit_count after this call; verify.

        :param immediate_reward: the latest reward sample
        :return: None
        """
        previous_total = self.mean_immediate_reward * self.visit_count
        self.mean_immediate_reward = (previous_total + immediate_reward) / (self.visit_count + 1)

    def add_child(self, node: Node) -> None:
        """
        Adds a child node to the tree. Since an action is always followed by an observation in the history,
        the next node will be an observation/belief node

        The node is appended to the list of children and registered under its observation. If a child
        with the same observation was registered before, the map entry is overwritten by the new node.

        :param node: the new child node to add
        :return: None
        """
        # NOTE(review): self.children is not created in this class; assumed to be set up by Node.__init__
        self.children.append(node)
        self.observation_to_node_map[node.observation] = node

    def get_child(self, key: int) -> Union[None, Node]:
        """
        Gets the child node corresponding to a specific observation

        :param key: the observation to get the node for
        :return: the child node or None if it was not found
        """
        return self.observation_to_node_map.get(key, None)