from typing import Union, List, Dict, Optional, Any
import math
import random
import time
import gymnasium as gym
import os
import numpy as np
import numpy.typing as npt
import gym_csle_stopping_game.constants.constants as env_constants
from csle_common.dao.emulation_config.emulation_env_config import EmulationEnvConfig
from csle_common.dao.simulation_config.simulation_env_config import SimulationEnvConfig
from csle_common.dao.training.experiment_config import ExperimentConfig
from csle_common.dao.training.experiment_execution import ExperimentExecution
from csle_common.dao.training.experiment_result import ExperimentResult
from csle_common.dao.training.agent_type import AgentType
from csle_common.dao.training.player_type import PlayerType
from csle_common.util.experiment_util import ExperimentUtil
from csle_common.logging.log import Logger
from csle_common.dao.training.multi_threshold_stopping_policy import MultiThresholdStoppingPolicy
from csle_common.dao.training.linear_threshold_stopping_policy import LinearThresholdStoppingPolicy
from csle_common.dao.training.policy_type import PolicyType
from csle_common.metastore.metastore_facade import MetastoreFacade
from csle_common.dao.jobs.training_job_config import TrainingJobConfig
from csle_common.dao.simulation_config.base_env import BaseEnv
import csle_common.constants.constants as constants
from csle_common.util.general_util import GeneralUtil
from csle_agents.agents.base.base_agent import BaseAgent
import csle_agents.constants.constants as agents_constants
from csle_agents.common.objective_type import ObjectiveType
class TSPSAAgent(BaseAgent):
"""
RL Agent implementing the T-SPSA algorithm from
(Hammar, Stadler 2021 - Intrusion Prevention through Optimal Stopping)
"""
def __init__(self, simulation_env_config: SimulationEnvConfig,
emulation_env_config: Union[None, EmulationEnvConfig],
experiment_config: ExperimentConfig, env: Optional[BaseEnv] = None,
training_job: Optional[TrainingJobConfig] = None, save_to_metastore: bool = True):
"""
Initializes the TSPSA agent
:param simulation_env_config: the simulation env config
:param emulation_env_config: the emulation env config
:param experiment_config: the experiment config
:param env: (optional) the gym environment to use for simulation
:param training_job: (optional) a training job configuration
:param save_to_metastore: boolean flag indicating whether results and progress should be saved to the metastore
"""
super().__init__(simulation_env_config=simulation_env_config, emulation_env_config=emulation_env_config,
experiment_config=experiment_config)
assert experiment_config.agent_type == AgentType.T_SPSA
self.env = env
self.training_job = training_job
self.save_to_metastore = save_to_metastore
def train(self) -> ExperimentExecution:
"""
Performs the policy training for the given random seeds using T-SPSA
:return: the training metrics and the trained policies
"""
pid = os.getpid()
# Initialize metrics
exp_result = ExperimentResult()
exp_result.plot_metrics.append(agents_constants.COMMON.AVERAGE_RETURN)
exp_result.plot_metrics.append(agents_constants.COMMON.RUNNING_AVERAGE_RETURN)
exp_result.plot_metrics.append(agents_constants.COMMON.WEIGHTED_INTRUSION_PREDICTION_DISTANCE)
exp_result.plot_metrics.append(agents_constants.COMMON.RUNNING_AVERAGE_WEIGHTED_INTRUSION_PREDICTION_DISTANCE)
exp_result.plot_metrics.append(agents_constants.COMMON.START_POINT_CORRECT)
exp_result.plot_metrics.append(agents_constants.COMMON.RUNNING_AVERAGE_START_POINT_CORRECT)
exp_result.plot_metrics.append(env_constants.ENV_METRICS.INTRUSION_LENGTH)
exp_result.plot_metrics.append(agents_constants.COMMON.RUNNING_AVERAGE_INTRUSION_LENGTH)
exp_result.plot_metrics.append(env_constants.ENV_METRICS.INTRUSION_START)
exp_result.plot_metrics.append(agents_constants.COMMON.RUNNING_AVERAGE_INTRUSION_START)
exp_result.plot_metrics.append(env_constants.ENV_METRICS.TIME_HORIZON)
exp_result.plot_metrics.append(agents_constants.COMMON.RUNNING_AVERAGE_TIME_HORIZON)
exp_result.plot_metrics.append(env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN)
exp_result.plot_metrics.append(env_constants.ENV_METRICS.AVERAGE_DEFENDER_BASELINE_STOP_ON_FIRST_ALERT_RETURN)
exp_result.plot_metrics.append(agents_constants.COMMON.RUNTIME)
for l in range(1, self.experiment_config.hparams[constants.T_SPSA.L].value + 1):
exp_result.plot_metrics.append(env_constants.ENV_METRICS.STOP + f"_{l}")
exp_result.plot_metrics.append(env_constants.ENV_METRICS.STOP + f"_running_average_{l}")
descr = f"Training of policies with the T-SPSA algorithm using " \
f"simulation:{self.simulation_env_config.name}"
for seed in self.experiment_config.random_seeds:
exp_result.all_metrics[seed] = {}
exp_result.all_metrics[seed][constants.T_SPSA.THETAS] = []
exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_RETURN] = []
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_RETURN] = []
exp_result.all_metrics[seed][agents_constants.COMMON.WEIGHTED_INTRUSION_PREDICTION_DISTANCE] = []
exp_result.all_metrics[seed][
agents_constants.COMMON.RUNNING_AVERAGE_WEIGHTED_INTRUSION_PREDICTION_DISTANCE] = []
exp_result.all_metrics[seed][agents_constants.COMMON.START_POINT_CORRECT] = []
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_START_POINT_CORRECT] = []
exp_result.all_metrics[seed][constants.T_SPSA.THRESHOLDS] = []
if self.experiment_config.player_type == PlayerType.DEFENDER:
for l in range(1, self.experiment_config.hparams[constants.T_SPSA.L].value + 1):
exp_result.all_metrics[seed][constants.T_SPSA.STOP_DISTRIBUTION_DEFENDER + f"_l={l}"] = []
else:
for s in self.simulation_env_config.state_space_config.states:
for l in range(1, self.experiment_config.hparams[constants.T_SPSA.L].value + 1):
exp_result.all_metrics[seed][constants.T_SPSA.STOP_DISTRIBUTION_ATTACKER
+ f"_l={l}_s={s.id}"] = []
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_INTRUSION_START] = []
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_TIME_HORIZON] = []
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_INTRUSION_LENGTH] = []
exp_result.all_metrics[seed][env_constants.ENV_METRICS.INTRUSION_START] = []
exp_result.all_metrics[seed][env_constants.ENV_METRICS.INTRUSION_LENGTH] = []
exp_result.all_metrics[seed][env_constants.ENV_METRICS.TIME_HORIZON] = []
exp_result.all_metrics[seed][env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN] = []
exp_result.all_metrics[seed][
env_constants.ENV_METRICS.AVERAGE_DEFENDER_BASELINE_STOP_ON_FIRST_ALERT_RETURN] = []
for l in range(1, self.experiment_config.hparams[constants.T_SPSA.L].value + 1):
exp_result.all_metrics[seed][env_constants.ENV_METRICS.STOP + f"_{l}"] = []
exp_result.all_metrics[seed][env_constants.ENV_METRICS.STOP + f"_running_average_{l}"] = []
exp_result.all_metrics[seed][agents_constants.COMMON.RUNTIME] = []
# Initialize training job
if self.training_job is None:
emulation_env_name = ""
if self.emulation_env_config is not None:
emulation_env_name = self.emulation_env_config.name
self.training_job = TrainingJobConfig(
simulation_env_name=self.simulation_env_config.name, experiment_config=self.experiment_config,
progress_percentage=0, pid=pid, experiment_result=exp_result,
emulation_env_name=emulation_env_name, simulation_traces=[],
num_cached_traces=agents_constants.COMMON.NUM_CACHED_SIMULATION_TRACES,
log_file_path=Logger.__call__().get_log_file_path(), descr=descr,
physical_host_ip=GeneralUtil.get_host_ip())
if self.save_to_metastore:
training_job_id = MetastoreFacade.save_training_job(training_job=self.training_job)
self.training_job.id = training_job_id
else:
self.training_job.pid = pid
self.training_job.progress_percentage = 0
self.training_job.experiment_result = exp_result
if self.save_to_metastore:
MetastoreFacade.update_training_job(training_job=self.training_job, id=self.training_job.id)
# Initialize execution result
ts = time.time()
emulation_name = ""
if self.emulation_env_config is not None:
emulation_name = self.emulation_env_config.name
simulation_name = self.simulation_env_config.name
self.exp_execution = ExperimentExecution(
result=exp_result, config=self.experiment_config, timestamp=ts,
emulation_name=emulation_name, simulation_name=simulation_name,
descr=descr, log_file_path=self.training_job.log_file_path)
if self.save_to_metastore:
exp_execution_id = MetastoreFacade.save_experiment_execution(self.exp_execution)
self.exp_execution.id = exp_execution_id
config = self.simulation_env_config.simulation_env_input_config
if self.env is None:
self.env = gym.make(self.simulation_env_config.gym_env_name, config=config)
for seed in self.experiment_config.random_seeds:
ExperimentUtil.set_seed(seed)
exp_result = self.spsa(exp_result=exp_result, seed=seed, training_job=self.training_job,
random_seeds=self.experiment_config.random_seeds)
# Save latest trace
if self.save_to_metastore:
MetastoreFacade.save_simulation_trace(self.env.get_traces()[-1])
self.env.reset_traces()
# Calculate average and std metrics
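# For every logged metric, the per-iteration mean and confidence interval are computed across the
# random seeds; non-numeric entries are replaced with the placeholder value -1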
exp_result.avg_metrics = {}
exp_result.std_metrics = {}
for metric in exp_result.all_metrics[self.experiment_config.random_seeds[0]].keys():
value_vectors = []
for seed in self.experiment_config.random_seeds:
value_vectors.append(exp_result.all_metrics[seed][metric])
avg_metrics = []
std_metrics = []
for i in range(len(value_vectors[0])):
if type(value_vectors[0][0]) is int or type(value_vectors[0][0]) is float \
or type(value_vectors[0][0]) is np.int64 or type(value_vectors[0][0]) is np.float64:
seed_values = []
for seed_idx in range(len(self.experiment_config.random_seeds)):
seed_values.append(value_vectors[seed_idx][i])
avg = ExperimentUtil.mean_confidence_interval(
data=seed_values,
confidence=self.experiment_config.hparams[agents_constants.COMMON.CONFIDENCE_INTERVAL].value)[0]
if not math.isnan(avg):
avg_metrics.append(avg)
ci = ExperimentUtil.mean_confidence_interval(
data=seed_values,
confidence=self.experiment_config.hparams[agents_constants.COMMON.CONFIDENCE_INTERVAL].value)[1]
if not math.isnan(ci):
std_metrics.append(ci)
else:
std_metrics.append(-1)
else:
avg_metrics.append(-1)
std_metrics.append(-1)
exp_result.avg_metrics[metric] = avg_metrics
exp_result.std_metrics[metric] = std_metrics
traces = self.env.get_traces()
if len(traces) > 0 and self.save_to_metastore:
MetastoreFacade.save_simulation_trace(traces[-1])
ts = time.time()
self.exp_execution.timestamp = ts
self.exp_execution.result = exp_result
if self.save_to_metastore:
MetastoreFacade.update_experiment_execution(experiment_execution=self.exp_execution,
id=self.exp_execution.id)
return self.exp_execution
def hparam_names(self) -> List[str]:
"""
:return: a list with the hyperparameter names
"""
return [constants.T_SPSA.a, constants.T_SPSA.c, constants.T_SPSA.LAMBDA,
constants.T_SPSA.A, constants.T_SPSA.EPSILON, constants.T_SPSA.N,
constants.T_SPSA.L, constants.T_SPSA.THETA1, agents_constants.COMMON.EVAL_BATCH_SIZE,
constants.T_SPSA.OBJECTIVE_TYPE,
constants.T_SPSA.GRADIENT_BATCH_SIZE, agents_constants.COMMON.CONFIDENCE_INTERVAL,
agents_constants.COMMON.RUNNING_AVERAGE]
def spsa(self, exp_result: ExperimentResult, seed: int,
training_job: TrainingJobConfig, random_seeds: List[int]) -> ExperimentResult:
"""
Runs the SPSA algorithm
:param exp_result: the experiment result object to store the result
:param seed: the seed
:param training_job: the training job config
:param random_seeds: list of seeds
:return: the updated experiment result, including the trained policy
"""
start: float = time.time()
objective_type = ObjectiveType(self.experiment_config.hparams[constants.T_SPSA.OBJECTIVE_TYPE].value)
L = self.experiment_config.hparams[constants.T_SPSA.L].value
if constants.T_SPSA.THETA1 in self.experiment_config.hparams:
theta = self.experiment_config.hparams[constants.T_SPSA.THETA1].value
else:
if self.experiment_config.player_type == PlayerType.DEFENDER:
theta = TSPSAAgent.initial_theta(L=L)
else:
theta = TSPSAAgent.initial_theta(L=2 * L)
# Initial eval
policy = self.get_policy(theta=theta, L=L)
avg_metrics = self.eval_theta(
policy=policy, max_steps=self.experiment_config.hparams[agents_constants.COMMON.MAX_ENV_STEPS].value)
J = round(avg_metrics[env_constants.ENV_METRICS.RETURN], 3)
policy.avg_R = J
exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_RETURN].append(J)
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_RETURN].append(J)
exp_result.all_metrics[seed][constants.T_SPSA.THETAS].append(TSPSAAgent.round_vec(theta))
time_elapsed_minutes = round((time.time() - start) / 60, 3)
exp_result.all_metrics[seed][agents_constants.COMMON.RUNTIME].append(time_elapsed_minutes)
# Hyperparameters
N = self.experiment_config.hparams[constants.T_SPSA.N].value
a = self.experiment_config.hparams[constants.T_SPSA.a].value
c = self.experiment_config.hparams[constants.T_SPSA.c].value
A = self.experiment_config.hparams[constants.T_SPSA.A].value
lamb = self.experiment_config.hparams[constants.T_SPSA.LAMBDA].value
epsilon = self.experiment_config.hparams[constants.T_SPSA.EPSILON].value
gradient_batch_size = self.experiment_config.hparams[constants.T_SPSA.GRADIENT_BATCH_SIZE].value
for i in range(N):
# Step sizes and perturbation size
ak = self.standard_ak(a=a, A=A, epsilon=epsilon, k=i)
ck = self.standard_ck(c=c, lamb=lamb, k=i)
# Get estimated gradient
gk = self.batch_gradient(theta=theta, ck=ck, L=L, k=i, gradient_batch_size=gradient_batch_size)
# Adjust theta using SA
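# Stochastic approximation update: theta_{k+1} = theta_k + a_k * g_k when maximizing,
# theta_{k+1} = theta_k - a_k * g_k when minimizing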
if objective_type == ObjectiveType.MAX:
theta = [t + ak * gkk for t, gkk in zip(theta, gk)]
else:
theta = [t - ak * gkk for t, gkk in zip(theta, gk)]
# Constrain (Theorem 1.A, Hammar Stadler 2021)
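# Enforce the monotone threshold ordering implied by the optimal stopping structure
# (e.g., theta[l] >= theta[l + 1] for the defender in the stopping game)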
if self.experiment_config.player_type == PlayerType.DEFENDER and "stop" in self.simulation_env_config.name:
for l in range(L - 1):
theta[l] = max(theta[l], theta[l + 1])
elif self.experiment_config.player_type == PlayerType.ATTACKER and "stop" \
in self.simulation_env_config.name:
for l in range(0, L - 1):
theta[l] = min(theta[l], theta[l + 1])
for l in range(L, 2 * L - 1):
theta[l] = max(theta[l], theta[l + 1])
elif self.experiment_config.player_type == PlayerType.DEFENDER and "tolerance" \
in self.simulation_env_config.name:
for l in range(L - 1):
theta[l] = min(theta[l], theta[l + 1])
# Evaluate new theta
policy = self.get_policy(theta=theta, L=L)
avg_metrics = self.eval_theta(
policy=policy, max_steps=self.experiment_config.hparams[agents_constants.COMMON.MAX_ENV_STEPS].value)
# Log average return
J = round(avg_metrics[env_constants.ENV_METRICS.RETURN], 3)
policy.avg_R = J
running_avg_J = ExperimentUtil.running_average(
exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_RETURN],
self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value)
exp_result.all_metrics[seed][agents_constants.COMMON.AVERAGE_RETURN].append(J)
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_RETURN].append(running_avg_J)
# Log runtime
time_elapsed_minutes = round((time.time() - start) / 60, 3)
exp_result.all_metrics[seed][agents_constants.COMMON.RUNTIME].append(time_elapsed_minutes)
exp_result.all_metrics[seed][constants.T_SPSA.THETAS].append(TSPSAAgent.round_vec(theta))
if self.experiment_config.hparams[constants.T_SPSA.POLICY_TYPE].value == PolicyType.MULTI_THRESHOLD.value:
# Log thresholds
exp_result.all_metrics[seed][constants.T_SPSA.THRESHOLDS].append(
TSPSAAgent.round_vec(policy.thresholds()))
# Log stop distribution
for k, v in policy.stop_distributions().items():
exp_result.all_metrics[seed][k].append(v)
# Log intrusion lengths
if env_constants.ENV_METRICS.INTRUSION_LENGTH in avg_metrics:
exp_result.all_metrics[seed][env_constants.ENV_METRICS.INTRUSION_LENGTH].append(
round(avg_metrics[env_constants.ENV_METRICS.INTRUSION_LENGTH], 3))
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_INTRUSION_LENGTH].append(
ExperimentUtil.running_average(
exp_result.all_metrics[seed][env_constants.ENV_METRICS.INTRUSION_LENGTH],
self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value))
# Log prediction distance
if env_constants.ENV_METRICS.WEIGHTED_INTRUSION_PREDICTION_DISTANCE in avg_metrics:
exp_result.all_metrics[seed][env_constants.ENV_METRICS.WEIGHTED_INTRUSION_PREDICTION_DISTANCE].append(
round(avg_metrics[env_constants.ENV_METRICS.WEIGHTED_INTRUSION_PREDICTION_DISTANCE], 3))
exp_result.all_metrics[seed][
agents_constants.COMMON.RUNNING_AVERAGE_WEIGHTED_INTRUSION_PREDICTION_DISTANCE].append(
ExperimentUtil.running_average(
exp_result.all_metrics[seed][env_constants.ENV_METRICS.WEIGHTED_INTRUSION_PREDICTION_DISTANCE],
self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value))
if env_constants.ENV_METRICS.START_POINT_CORRECT in avg_metrics:
exp_result.all_metrics[seed][env_constants.ENV_METRICS.START_POINT_CORRECT].append(
round(avg_metrics[env_constants.ENV_METRICS.START_POINT_CORRECT], 3))
exp_result.all_metrics[seed][
agents_constants.COMMON.RUNNING_AVERAGE_START_POINT_CORRECT].append(
ExperimentUtil.running_average(
exp_result.all_metrics[seed][env_constants.ENV_METRICS.START_POINT_CORRECT],
self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value))
# Log stopping times
if env_constants.ENV_METRICS.INTRUSION_START in avg_metrics:
exp_result.all_metrics[seed][env_constants.ENV_METRICS.INTRUSION_START].append(
round(avg_metrics[env_constants.ENV_METRICS.INTRUSION_START], 3))
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_INTRUSION_START].append(
ExperimentUtil.running_average(
exp_result.all_metrics[seed][env_constants.ENV_METRICS.INTRUSION_START],
self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value))
exp_result.all_metrics[seed][env_constants.ENV_METRICS.TIME_HORIZON].append(
round(avg_metrics[env_constants.ENV_METRICS.TIME_HORIZON], 3))
exp_result.all_metrics[seed][agents_constants.COMMON.RUNNING_AVERAGE_TIME_HORIZON].append(
ExperimentUtil.running_average(
exp_result.all_metrics[seed][env_constants.ENV_METRICS.TIME_HORIZON],
self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value))
for l in range(1, self.experiment_config.hparams[constants.T_SPSA.L].value + 1):
if env_constants.ENV_METRICS.STOP + f"_{l}" in avg_metrics:
exp_result.plot_metrics.append(env_constants.ENV_METRICS.STOP + f"_{l}")
exp_result.all_metrics[seed][env_constants.ENV_METRICS.STOP + f"_{l}"].append(
round(avg_metrics[env_constants.ENV_METRICS.STOP + f"_{l}"], 3))
exp_result.all_metrics[seed][env_constants.ENV_METRICS.STOP + f"_running_average_{l}"].append(
ExperimentUtil.running_average(
exp_result.all_metrics[seed][env_constants.ENV_METRICS.STOP + f"_{l}"],
self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value))
# Log baseline returns
if env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN in avg_metrics:
exp_result.all_metrics[seed][env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN].append(
round(avg_metrics[env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN], 3))
if env_constants.ENV_METRICS.AVERAGE_DEFENDER_BASELINE_STOP_ON_FIRST_ALERT_RETURN in avg_metrics:
exp_result.all_metrics[seed][
env_constants.ENV_METRICS.AVERAGE_DEFENDER_BASELINE_STOP_ON_FIRST_ALERT_RETURN].append(
round(avg_metrics[env_constants.ENV_METRICS.AVERAGE_DEFENDER_BASELINE_STOP_ON_FIRST_ALERT_RETURN],
3))
if i % self.experiment_config.log_every == 0:
# Update training job
total_iterations = len(random_seeds) * N
iterations_done = (random_seeds.index(seed)) * N + i
progress = round(iterations_done / total_iterations, 2)
training_job.progress_percentage = progress
training_job.experiment_result = exp_result
if self.env is not None and len(self.env.get_traces()) > 0:
training_job.simulation_traces.append(self.env.get_traces()[-1])
if len(training_job.simulation_traces) > training_job.num_cached_traces:
training_job.simulation_traces = training_job.simulation_traces[1:]
if self.save_to_metastore:
MetastoreFacade.update_training_job(training_job=training_job, id=training_job.id)
# Update execution
ts = time.time()
self.exp_execution.timestamp = ts
self.exp_execution.result = exp_result
if self.save_to_metastore:
MetastoreFacade.update_experiment_execution(experiment_execution=self.exp_execution,
id=self.exp_execution.id)
opt_j = -1
if len(exp_result.all_metrics[seed][env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN]) > 0:
opt_j = exp_result.all_metrics[seed][env_constants.ENV_METRICS.AVERAGE_UPPER_BOUND_RETURN][-1]
Logger.__call__().get_logger().info(
f"[T-SPSA] i: {i}, J:{J}, "
f"J_avg_{self.experiment_config.hparams[agents_constants.COMMON.RUNNING_AVERAGE].value}:"
f"{running_avg_J}, "
f"opt_J:{opt_j}, "
f"theta:{policy.theta}, progress: {round(progress * 100, 2)}%, runtime: {time_elapsed_minutes} min")
policy = self.get_policy(theta=theta, L=L)
exp_result.policies[seed] = policy
# Save policy
if self.save_to_metastore:
MetastoreFacade.save_multi_threshold_stopping_policy(multi_threshold_stopping_policy=policy)
return exp_result
def eval_theta(self, policy: Union[MultiThresholdStoppingPolicy, LinearThresholdStoppingPolicy],
max_steps: int = 200) -> Dict[str, Any]:
"""
Evaluates a given threshold policy by running Monte Carlo simulations
:param policy: the policy to evaluate
:param max_steps: the maximum number of steps per evaluation episode
:return: the average metrics of the evaluation
"""
if self.env is None:
raise ValueError("An environment must be specified to run policy evaluation")
eval_batch_size = self.experiment_config.hparams[agents_constants.COMMON.EVAL_BATCH_SIZE].value
metrics: Dict[str, Any] = {}
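# Monte Carlo evaluation: roll out eval_batch_size episodes, each for at most max_steps steps,
# and aggregate the metrics from the final info dict of each episode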
for j in range(eval_batch_size):
done = False
o, _ = self.env.reset()
l = int(o[0])
b1 = o[1]
t = 1
r = 0
a = 0
info: Dict[str, Any] = {}
while not done and t <= max_steps:
Logger.__call__().get_logger().debug(f"t:{t}, a:{a}, b1:{b1}, r:{r}, l:{l}, info:{info}")
if self.experiment_config.player_type == PlayerType.ATTACKER:
policy.opponent_strategy = self.env.static_defender_strategy
a = policy.action(o=o)
else:
a = policy.action(o=o)
o, r, done, _, info = self.env.step(a)
l = int(o[0])
b1 = o[1]
t += 1
metrics = TSPSAAgent.update_metrics(metrics=metrics, info=info)
avg_metrics = TSPSAAgent.compute_avg_metrics(metrics=metrics)
return avg_metrics
@staticmethod
def update_metrics(metrics: Dict[str, List[Union[float, int]]], info: Dict[str, Union[float, int]]) \
-> Dict[str, List[Union[float, int]]]:
"""
Update a dict with aggregated metrics using new information from the environment
:param metrics: the dict with the aggregated metrics
:param info: the new information
:return: the updated dict
"""
for k, v in info.items():
if k in metrics:
metrics[k].append(round(v, 3))
else:
metrics[k] = [round(v, 3)]
return metrics
@staticmethod
def compute_avg_metrics(metrics: Dict[str, List[Union[float, int]]]) -> Dict[str, Union[float, int]]:
"""
Computes the average metrics of a dict with aggregated metrics
:param metrics: the dict with the aggregated metrics
:return: the average metrics
"""
avg_metrics = {}
for k, v in metrics.items():
avg = round(sum(v) / len(v), 2)
avg_metrics[k] = avg
return avg_metrics
@staticmethod
def standard_ak(a: int, A: int, epsilon: float, k: int) -> float:
"""
Gets the step size for gradient ascent at iteration k
:param a: a scalar hyperparameter
:param A: a scalar hyperparameter
:param epsilon: the epsilon scalar hyperparameter
:param k: the iteration index
:return: the step size a_k
"""
return float(a / (k + 1 + A) ** epsilon)
@staticmethod
def standard_ck(c: float, lamb: float, k: int) -> float:
"""
Gets the step size of perturbations at iteration k
:param c: a scalar hyperparameter
:param lamb: (lambda) a scalar hyperparameter
:param k: the iteration
:return: the perturbation step size
"""
return float(c / (k + 1) ** lamb)
@staticmethod
def standard_deltak(dimension: int, k: int) -> List[float]:
"""
Gets the perturbation direction at iteration k
:param k: the iteration
:param dimension: the dimension of the perturbation vector
:return: delta_k the perturbation vector at iteration k
"""
return [random.choice((-1, 1)) for _ in range(dimension)]
@staticmethod
def initial_theta(L: int) -> npt.NDArray[Any]:
"""
Initializes theta randomly
:param L: the dimension of theta
:return: the initialized theta vector
"""
theta_1 = []
for k in range(L):
theta_1.append(np.random.uniform(-3, 3))
return np.array(theta_1)
def batch_gradient(self, theta: List[float], ck: float, L: int, k: int,
gradient_batch_size: int = 1) -> List[float]:
"""
Computes a batch of gradients and returns the average
:param theta: the current parameter vector
:param k: the current training iteration
:param ck: the perturbation step size
:param L: the total number of stops for the defender
:param gradient_batch_size: the number of gradients to include in the batch
:return: the average of the batch of gradients
"""
gradients = []
for i in range(gradient_batch_size):
deltak_i = self.standard_deltak(dimension=len(theta), k=k)
gk_i = self.estimate_gk(theta=theta, deltak=deltak_i, ck=ck, L=L)
gradients.append(gk_i)
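# Average the independent SPSA gradient estimates to reduce the variance of the estimator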
batch_gk = np.array(gradients).mean(axis=0).tolist()
return batch_gk
def estimate_gk(self, theta: List[float], deltak: List[float], ck: float, L: int) -> List[float]:
"""
Estimate the gradient at iteration k of the T-SPSA algorithm
:param theta: the current parameter vector
:param deltak: the perturbation direction vector
:param ck: the perturbation step size
:param L: the total number of stops for the defender
:return: the estimated gradient
"""
# Get the two perturbed parameter vectors, theta + ck * deltak and theta - ck * deltak
ta = [t + ck * dk for t, dk in zip(theta, deltak)]
tb = [t - ck * dk for t, dk in zip(theta, deltak)]
# Calculate g_k(theta_k)
policy_1 = self.get_policy(theta=ta, L=L)
avg_metrics = self.eval_theta(
policy_1,
max_steps=self.experiment_config.hparams[agents_constants.COMMON.MAX_ENV_STEPS].value
)
J_a = round(avg_metrics[env_constants.ENV_METRICS.RETURN], 3)
policy_2 = self.get_policy(theta=tb, L=L)
avg_metrics = self.eval_theta(
policy_2,
max_steps=self.experiment_config.hparams[agents_constants.COMMON.MAX_ENV_STEPS].value)
J_b = round(avg_metrics[env_constants.ENV_METRICS.RETURN], 3)
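# Two-sided SPSA gradient estimate:
# g_k[i] = (J(theta + c_k * delta_k) - J(theta - c_k * delta_k)) / (2 * c_k * delta_k[i])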
gk = [(J_a - J_b) / (2 * ck * dk) for dk in deltak]
return gk
@staticmethod
def round_vec(vec) -> List[float]:
"""
Rounds a vector to 3 decimals
:param vec: the vector to round
:return: the rounded vector
"""
return list(map(lambda x: round(x, 3), vec))
def get_policy(self, theta: List[float], L: int) -> Union[MultiThresholdStoppingPolicy,
LinearThresholdStoppingPolicy]:
"""
Gets the policy from a parameter vector
:param theta: the parameter vector
:param L: the total number of stops
:return: the policy object
"""
if self.experiment_config.hparams[constants.T_SPSA.POLICY_TYPE].value == PolicyType.MULTI_THRESHOLD.value:
policy = MultiThresholdStoppingPolicy(
theta=theta, simulation_name=self.simulation_env_config.name,
states=self.simulation_env_config.state_space_config.states,
player_type=self.experiment_config.player_type, L=L,
actions=self.simulation_env_config.joint_action_space_config.action_spaces[
self.experiment_config.player_idx].actions, experiment_config=self.experiment_config, avg_R=-1,
agent_type=AgentType.T_SPSA)
else:
policy = LinearThresholdStoppingPolicy(
theta=theta, simulation_name=self.simulation_env_config.name,
states=self.simulation_env_config.state_space_config.states,
player_type=self.experiment_config.player_type, L=L,
actions=self.simulation_env_config.joint_action_space_config.action_spaces[
self.experiment_config.player_idx].actions, experiment_config=self.experiment_config, avg_R=-1,
agent_type=AgentType.T_SPSA)
return policy
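if __name__ == "__main__":
    # Illustrative, self-contained sketch of the SPSA recursion using the static helpers above.
    # It optimizes a toy quadratic objective J(theta) = -||theta||^2 (maximizer theta = 0) rather
    # than the stopping-game return; the hyperparameter values below are arbitrary demo assumptions.
    def toy_J(theta_vec: List[float]) -> float:
        return -float(sum(t ** 2 for t in theta_vec))

    demo_theta = list(TSPSAAgent.initial_theta(L=3))
    demo_a, demo_c, demo_A, demo_epsilon, demo_lamb = 1.0, 1.0, 10, 0.602, 0.101
    for demo_k in range(200):
        demo_ak = TSPSAAgent.standard_ak(a=demo_a, A=demo_A, epsilon=demo_epsilon, k=demo_k)
        demo_ck = TSPSAAgent.standard_ck(c=demo_c, lamb=demo_lamb, k=demo_k)
        demo_delta = TSPSAAgent.standard_deltak(dimension=len(demo_theta), k=demo_k)
        # Two-sided SPSA gradient estimate of the toy objective
        J_plus = toy_J([t + demo_ck * d for t, d in zip(demo_theta, demo_delta)])
        J_minus = toy_J([t - demo_ck * d for t, d in zip(demo_theta, demo_delta)])
        demo_gk = [(J_plus - J_minus) / (2 * demo_ck * d) for d in demo_delta]
        # Gradient-ascent step (maximization), mirroring the update in TSPSAAgent.spsa
        demo_theta = [t + demo_ak * g for t, g in zip(demo_theta, demo_gk)]
    print(f"theta after 200 toy SPSA iterations (maximizer is theta = 0): "
          f"{TSPSAAgent.round_vec(demo_theta)}")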