Source code for csle_common.dao.training.mixed_multi_threshold_stopping_policy

from typing import List, Dict, Tuple, Union, Optional, Any
import random
import numpy as np
import csle_common.constants.constants as constants
from csle_common.dao.training.policy import Policy
from csle_common.dao.training.agent_type import AgentType
from csle_common.dao.simulation_config.state import State
from csle_common.dao.training.player_type import PlayerType
from csle_common.dao.simulation_config.action import Action
from csle_common.dao.training.experiment_config import ExperimentConfig
from csle_common.dao.simulation_config.state_type import StateType
from csle_common.dao.training.policy_type import PolicyType
from csle_common.dao.training.multi_threshold_stopping_policy import MultiThresholdStoppingPolicy


class MixedMultiThresholdStoppingPolicy(Policy):
    """
    A mixed multi-threshold stopping policy
    """

    def __init__(self, defender_Theta: List[List[List[float]]], attacker_Theta: List[List[List[List[float]]]],
                 simulation_name: str, L: int, states: List[State], player_type: PlayerType, actions: List[Action],
                 experiment_config: Optional[ExperimentConfig], avg_R: float, agent_type: AgentType,
                 opponent_strategy: Optional["MixedMultiThresholdStoppingPolicy"] = None):
        """
        Initializes the policy

        :param defender_Theta: the buffer with threshold vectors for the defender
        :param attacker_Theta: the buffer with threshold vectors for the attacker
        :param simulation_name: the simulation name
        :param player_type: the player type (attacker or defender)
        :param L: the number of stop actions
        :param states: list of states (required for computing stage policies)
        :param actions: list of actions
        :param experiment_config: the experiment configuration used for training the policy
        :param avg_R: the average reward of the policy when evaluated in the simulation
        :param agent_type: the agent type
        :param opponent_strategy: the opponent's strategy
        """
        super(MixedMultiThresholdStoppingPolicy, self).__init__(agent_type=agent_type, player_type=player_type)
        self.attacker_Theta = attacker_Theta
        self.defender_Theta = defender_Theta
        self.id = -1
        self.simulation_name = simulation_name
        self.L = L
        self.states = states
        self.actions = actions
        self.experiment_config = experiment_config
        self.avg_R = avg_R
        self.opponent_strategy = opponent_strategy
        self.policy_type = PolicyType.MIXED_MULTI_THRESHOLD
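
    # Construction sketch (hypothetical, not part of the module): all concrete values and
    # names below are made up. The defender_Theta buffer stores, for each stop level
    # l = 1..L, a pair [thresholds, counts]; the counts are normalized into mixture
    # weights in _defender_action. AgentType.T_SPSA is assumed to be an available agent type.
    #
    # defender_Theta = [[[0.4, 0.7], [3, 1]]]   # L=1: two thresholds, observed 3 and 1 times
    # attacker_Theta = [[[]], [[]]]             # empty buffers for s=0 and s=1 (unused by a defender policy)
    # policy = MixedMultiThresholdStoppingPolicy(
    #     defender_Theta=defender_Theta, attacker_Theta=attacker_Theta, simulation_name="my-simulation",
    #     L=1, states=[], player_type=PlayerType.DEFENDER, actions=[], experiment_config=None,
    #     avg_R=0.0, agent_type=AgentType.T_SPSA)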

    def probability(self, o: List[float], a: int) -> int:
        """
        Probability of a given action

        :param o: the current observation
        :param a: a given action
        :return: the probability of a
        """
        return self.action(o=o) == a

    def action(self, o: List[float], deterministic: bool = True) -> int:
        """
        Multi-threshold stopping policy

        :param o: the current observation
        :param deterministic: boolean flag indicating whether the action selection should be deterministic
        :return: the selected action
        """
        if not self.player_type == PlayerType.ATTACKER:
            a, _ = self._defender_action(o=o)
            return a
        else:
            if self.opponent_strategy is None:
                raise ValueError("Must specify the opponent strategy")
            a1, defender_stop_probability = self.opponent_strategy._defender_action(o=o)
            if a1 == 0:
                defender_stop_probability = 1 - defender_stop_probability
            a, _ = self._attacker_action(o=o, defender_stop_probability=defender_stop_probability)
            return a
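
    # Usage sketch (hypothetical values): for the defender, the observation is [l, b, s],
    # where l is the number of stops left and b is the belief that the intrusion has started.
    # action() returns 1 (stop) with the mixture's stopping probability and 0 (continue) otherwise.
    #
    # o = [1, 0.8, 0]          # made-up observation: l=1, b=0.8, s=0
    # a = policy.action(o=o)   # using the policy object from the construction sketch above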

    def _attacker_action(self, o: List[float], defender_stop_probability: float) -> Tuple[int, float]:
        """
        Multi-threshold stopping policy of the attacker

        :param o: the input observation
        :param defender_stop_probability: the defender's stopping probability
        :return: the selected action (int) and the probability of selecting that action
        """
        s = int(o[2])
        if s == 2:
            return 0, 1.0
        l = int(o[0])
        thresholds = self.attacker_Theta[s][l - 1][0]
        counts = self.attacker_Theta[s][l - 1][1]
        mixture_weights = np.array(counts) / sum(counts)
        prob = 0
        for i in range(len(thresholds)):
            if defender_stop_probability >= thresholds[i]:
                prob += mixture_weights[i] * MultiThresholdStoppingPolicy.stopping_probability(
                    b1=defender_stop_probability, threshold=thresholds[i], k=-20)
        if s == 1:
            a = 0
            if random.uniform(0, 1) < prob:
                a = 1
            else:
                prob = 1 - prob
        else:
            a = 1
            if random.uniform(0, 1) < prob:
                a = 0
            else:
                prob = 1 - prob
        prob = round(prob, 3)
        return a, prob

    def _defender_action(self, o: List[float]) -> Tuple[int, float]:
        """
        Multi-threshold stopping policy of the defender

        :param o: the input observation
        :return: the selected action (int), and the probability of selecting that action
        """
        b1 = o[1]
        l = int(o[0])
        thresholds = self.defender_Theta[l - 1][0]
        counts = self.defender_Theta[l - 1][1]
        mixture_weights = np.array(counts) / sum(counts)
        stop_probability = 0
        for i, thresh in enumerate(thresholds):
            if b1 >= thresh:
                stop_probability += mixture_weights[i] * MultiThresholdStoppingPolicy.stopping_probability(
                    b1=b1, threshold=thresh, k=-20)
        stop_probability = round(stop_probability, 3)
        if random.uniform(0, 1) >= stop_probability:
            return 0, 1 - stop_probability
        else:
            return 1, stop_probability
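
    # Worked sketch (hypothetical numbers): with defender_Theta[l - 1] = [[0.4, 0.7], [3, 1]],
    # the mixture weights are [0.75, 0.25]. For a belief b1 = 0.8 both thresholds are exceeded,
    # so the stop probability is
    #     0.75 * sigma(0.8, 0.4) + 0.25 * sigma(0.8, 0.7)
    # where sigma denotes MultiThresholdStoppingPolicy.stopping_probability with k=-20,
    # a smooth (sigmoid-like) relaxation of the hard threshold rule.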

    def to_dict(self) -> Dict[str, Any]:
        """
        :return: a dict representation of the policy
        """
        d: Dict[str, Any] = {}
        d["attacker_Theta"] = self.attacker_Theta
        d["defender_Theta"] = self.defender_Theta
        if self.player_type == PlayerType.ATTACKER:
            d["Theta"] = self.attacker_Theta
        if self.player_type == PlayerType.DEFENDER:
            d["Theta"] = self.defender_Theta
        d["id"] = self.id
        d["simulation_name"] = self.simulation_name
        d["states"] = list(map(lambda x: x.to_dict(), self.states))
        d["actions"] = list(map(lambda x: x.to_dict(), self.actions))
        d["player_type"] = self.player_type
        d["agent_type"] = self.agent_type
        d["L"] = self.L
        if self.experiment_config is not None:
            d["experiment_config"] = self.experiment_config.to_dict()
        else:
            d["experiment_config"] = None
        d["avg_R"] = self.avg_R
        d["policy_type"] = self.policy_type
        return d

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "MixedMultiThresholdStoppingPolicy":
        """
        Converts a dict representation of the object to an instance

        :param d: the dict to convert
        :return: the created instance
        """
        if d["experiment_config"] is not None:
            experiment_config = ExperimentConfig.from_dict(d["experiment_config"])
        else:
            experiment_config = None
        obj = MixedMultiThresholdStoppingPolicy(
            defender_Theta=d["defender_Theta"], attacker_Theta=d["attacker_Theta"],
            simulation_name=d["simulation_name"], L=d["L"],
            states=list(map(lambda x: State.from_dict(x), d["states"])), player_type=d["player_type"],
            actions=list(map(lambda x: Action.from_dict(x), d["actions"])),
            experiment_config=experiment_config, avg_R=d["avg_R"], agent_type=d["agent_type"])
        obj.id = d["id"]
        return obj
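
    # Round-trip sketch (hypothetical): to_dict() produces a dict representation and
    # from_dict() reconstructs an equivalent policy. The equality check assumes that the
    # nested State, Action, and ExperimentConfig DTOs round-trip exactly.
    #
    # d = policy.to_dict()
    # restored = MixedMultiThresholdStoppingPolicy.from_dict(d)
    # assert restored == policy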

    def _update_Theta_attacker(self, new_thresholds: List[List[List[float]]]) -> None:
        """
        Updates the Theta buffer with new attacker thresholds

        :param new_thresholds: the new thresholds to add to the buffer
        :return: None
        """
        for a_thresholds in new_thresholds:
            for s in range(2):
                for l in range(self.L):
                    if len(self.attacker_Theta[s][l]) == 0:
                        self.attacker_Theta[s][l] = [[a_thresholds[s][l]], [1]]
                    else:
                        if a_thresholds[s][l] in self.attacker_Theta[s][l][0]:
                            i = self.attacker_Theta[s][l][0].index(a_thresholds[s][l])
                            self.attacker_Theta[s][l][1][i] = self.attacker_Theta[s][l][1][i] + 1
                        else:
                            self.attacker_Theta[s][l][0].append(a_thresholds[s][l])
                            self.attacker_Theta[s][l][1].append(1)

    def _update_Theta_defender(self, new_thresholds: List[List[float]]) -> None:
        """
        Updates the Theta buffer with new defender thresholds

        :param new_thresholds: the new thresholds to add to the buffer
        :return: None
        """
        for defender_theta in new_thresholds:
            for l in range(self.L):
                if len(self.defender_Theta[l]) == 0:
                    self.defender_Theta[l] = [[defender_theta[l]], [1]]
                else:
                    if defender_theta[l] in self.defender_Theta[l][0]:
                        i = self.defender_Theta[l][0].index(defender_theta[l])
                        self.defender_Theta[l][1][i] = self.defender_Theta[l][1][i] + 1
                    else:
                        self.defender_Theta[l][0].append(defender_theta[l])
                        self.defender_Theta[l][1].append(1)
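
    # Update sketch (hypothetical numbers): each new threshold vector either increments the
    # count of a threshold already in the buffer or appends the threshold with count 1.
    # With self.defender_Theta = [[[0.4], [2]]] and L = 1, the call
    # self._update_Theta_defender(new_thresholds=[[0.4], [0.7]]) yields
    # self.defender_Theta == [[[0.4, 0.7], [3, 1]]].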

    def stage_policy(self, o: List[Union[int, float]]) -> List[List[float]]:
        """
        Returns the stage policy for a given observation

        :param o: the observation to return the stage policy for
        :return: the stage policy
        """
        b1 = o[1]
        l = int(o[0])
        if not self.player_type == PlayerType.ATTACKER:
            stage_policy = []
            for _ in self.states:
                a1, stopping_probability = self._defender_action(o=o)
                if a1 == 0:
                    stopping_probability = 1 - stopping_probability
                stage_policy.append([1 - stopping_probability, stopping_probability])
            return stage_policy
        else:
            stage_policy = []
            if self.opponent_strategy is None:
                raise ValueError("Opponent strategy is None and must be specified.")
            a1, defender_stop_probability = self.opponent_strategy._defender_action(o=o)
            if a1 == 0:
                defender_stop_probability = 1 - defender_stop_probability
            for s in self.states:
                if s.state_type != StateType.TERMINAL:
                    o = [l, b1, s.id]
                    a, action_prob = self._attacker_action(o=o, defender_stop_probability=defender_stop_probability)
                    if action_prob == 1:
                        action_prob = 0.99
                    if action_prob == 0:
                        action_prob = 0.01
                    if a == 1:
                        stage_policy.append([1 - action_prob, action_prob])
                    elif a == 0:
                        stage_policy.append([action_prob, 1 - action_prob])
                    else:
                        raise ValueError(f"Invalid action: {a}, valid actions are: 0 and 1")
                else:
                    stage_policy.append([0.5, 0.5])
            return stage_policy
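
    # Output sketch: stage_policy returns one [P(continue), P(stop)] row per state. For the
    # defender every row equals [1 - p_stop, p_stop] for the stop probability induced by o;
    # for the attacker, non-terminal states get the attacker's (clipped) action distribution
    # and terminal states get [0.5, 0.5].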

    def stop_distributions(self) -> Dict[str, List[float]]:
        """
        :return: the stop distributions and their names
        """
        distributions: Dict[str, List[float]] = {}
        if self.player_type == PlayerType.DEFENDER:
            belief_space = np.linspace(0, 1, num=100)
            for l in range(1, self.L + 1):
                stop_dist: List[float] = []
                for b in belief_space:
                    a1, prob = self._defender_action(o=[l, b])
                    if a1 == 1:
                        stop_dist.append(round(prob, 3))
                    else:
                        stop_dist.append(round(1 - prob, 3))
                distributions[constants.T_SPSA.STOP_DISTRIBUTION_DEFENDER + f"_l={l}"] = stop_dist
        else:
            defender_stop_space = np.linspace(0, 1, num=100)
            for s in self.states:
                if s.state_type != StateType.TERMINAL:
                    for l in range(1, self.L + 1):
                        stop_dist = []
                        for pi_1_S in defender_stop_space:
                            a2, prob = self._attacker_action(o=[l, pi_1_S, s.id],
                                                             defender_stop_probability=pi_1_S)
                            if a2 == 1:
                                stop_dist.append(round(prob, 3))
                            else:
                                stop_dist.append(round(1 - prob, 3))
                        distributions[constants.T_SPSA.STOP_DISTRIBUTION_ATTACKER + f"_l={l}_s={s.id}"] = stop_dist
        return distributions
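
    # Usage sketch (hypothetical): for a defender policy, the returned dict maps keys of the
    # form constants.T_SPSA.STOP_DISTRIBUTION_DEFENDER + "_l=<l>" to 100 stop probabilities
    # evaluated on a uniform grid over the belief space [0, 1].
    #
    # dists = policy.stop_distributions()
    # for name, dist in dists.items():
    #     print(name, len(dist))   # each distribution has 100 entries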

    def __str__(self) -> str:
        """
        :return: a string representation of the object
        """
        return f"defender_Theta: {self.defender_Theta}, attacker_Theta: {self.attacker_Theta}, " \
               f"id: {self.id}, simulation_name: {self.simulation_name}, " \
               f"player_type: {self.player_type}, " \
               f"L: {self.L}, states: {self.states}, agent_type: {self.agent_type}, actions: {self.actions}, " \
               f"experiment_config: {self.experiment_config}, avg_R: {self.avg_R}, policy_type: {self.policy_type}"

    @staticmethod
    def from_json_file(json_file_path: str) -> "MixedMultiThresholdStoppingPolicy":
        """
        Reads a json file and converts it to a DTO

        :param json_file_path: the json file path
        :return: the converted DTO
        """
        import io
        import json
        with io.open(json_file_path, 'r') as f:
            json_str = f.read()
        return MixedMultiThresholdStoppingPolicy.from_dict(json.loads(json_str))
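
    # Usage sketch (hypothetical file path): load a policy that was previously serialized
    # to JSON from its dict representation.
    #
    # policy = MixedMultiThresholdStoppingPolicy.from_json_file("/tmp/mixed_policy.json")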

    def copy(self) -> "MixedMultiThresholdStoppingPolicy":
        """
        :return: a copy of the DTO
        """
        return self.from_dict(self.to_dict())

    def __eq__(self, other: object) -> bool:
        """
        Compares equality with another object

        :param other: the object to compare with
        :return: True if equals, False otherwise
        """
        if not isinstance(other, MixedMultiThresholdStoppingPolicy):
            return False
        else:
            d1 = self.to_dict()
            d2 = other.to_dict()
            del d1["Theta"]
            del d2["Theta"]
            return d1 == d2