import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
class CounterfactualExplainer:
    """
    Generate counterfactual explanations: "What is the minimum change
    to flip the model's prediction?"

    Security use case: "What would need to change for this host to be
    classified as low-risk?"

    Method: gradient-free greedy coordinate search. On every iteration each
    feature is nudged by -step and +step, and the single nudge that gives the
    highest probability of the target class is applied.

    The search operates in whatever space ``x`` is supplied in (typically the
    scaled space the model was trained on). ``scaler`` is stored for callers
    but is not used by the search itself.
    """

    def __init__(self, model, scaler, feature_names: list):
        # model must provide predict() and predict_proba() on 2-D input.
        self.model = model
        self.scaler = scaler
        self.features = feature_names

    def _proba_column(self, target_class) -> int:
        """Return the ``predict_proba`` column index that holds ``target_class``.

        Uses ``model.classes_`` when available (sklearn orders probability
        columns by it); otherwise assumes labels are 0..k-1.

        Raises:
            ValueError: if ``target_class`` is not one of ``model.classes_``.
        """
        classes = getattr(self.model, "classes_", None)
        if classes is None:
            return int(target_class)
        try:
            return list(classes).index(target_class)
        except ValueError:
            raise ValueError(
                f"target_class {target_class!r} not in model.classes_ {list(classes)}"
            ) from None

    def find_counterfactual(self, x: np.ndarray, target_class: int = 0,
                            max_iter: int = 200, step: float = 0.05) -> tuple:
        """Greedily perturb ``x`` until the model predicts ``target_class``.

        Args:
            x: 1-D feature vector in the model's input space.
            target_class: label the counterfactual should be classified as.
            max_iter: maximum number of search iterations.
            step: size of each single-feature nudge.

        Returns:
            ``(x_cf, n_steps)``: the (float) counterfactual and the number of
            nudges applied before the prediction matched ``target_class``.
            If the search does not converge, ``n_steps`` is ``max_iter`` and
            ``x_cf`` is the last point reached.
        """
        # Float copy: in-place float nudges on an int array would raise.
        x_cf = np.array(x, dtype=float)
        n_features = x_cf.shape[0]
        col = self._proba_column(target_class)

        # All 2*n_features candidate nudges, ordered (f0-, f0+, f1-, f1+, ...)
        # so argmax's first-maximum tie-break matches a feature-by-feature scan.
        feats = np.repeat(np.arange(n_features), 2)
        deltas = np.tile([-step, step], n_features)
        offsets = np.zeros((2 * n_features, n_features))
        offsets[np.arange(2 * n_features), feats] = deltas

        for i in range(max_iter):
            if self.model.predict(x_cf.reshape(1, -1))[0] == target_class:
                return x_cf, i
            # Score every candidate nudge in a single batched call.
            probs = self.model.predict_proba(x_cf + offsets)[:, col]
            best = int(np.argmax(probs))
            if probs[best] <= 0:
                # No nudge gives the target class any probability, so x_cf
                # would never change again: stop instead of spinning.
                break
            x_cf[feats[best]] += deltas[best]
        return x_cf, max_iter  # not converged

    def explain(self, x: np.ndarray, target_class: int = 0, *,
                max_iter: int = 200, step: float = 0.05,
                max_changes: int = 5, tol: float = 0.01):
        """Run the counterfactual search and summarise the feature changes.

        Args:
            x: 1-D feature vector in the model's input space.
            target_class: label the counterfactual should be classified as.
            max_iter, step: forwarded to :meth:`find_counterfactual`.
            max_changes: maximum number of changed features to report.
            tol: minimum absolute change for a feature to be reported.

        Returns:
            dict with keys:
              'counterfactual': the counterfactual vector,
              'changes': up to ``max_changes`` tuples
                  ``(feature_name, original, counterfactual, delta)`` rounded
                  to 3 decimals, sorted by ``|delta|`` descending,
              'n_steps': as returned by :meth:`find_counterfactual`,
              'converged': whether the model predicts ``target_class`` for
                  the counterfactual.
        """
        x = np.asarray(x, dtype=float)
        x_cf, n_steps = self.find_counterfactual(x, target_class,
                                                 max_iter=max_iter, step=step)
        changes = [
            (self.features[i], round(float(x[i]), 3), round(float(x_cf[i]), 3),
             round(float(d), 3))
            for i, d in enumerate(x_cf - x) if abs(d) > tol
        ]
        changes.sort(key=lambda c: abs(c[3]), reverse=True)
        converged = self.model.predict(x_cf.reshape(1, -1))[0] == target_class
        return {'counterfactual': x_cf, 'changes': changes[:max_changes],
                'n_steps': n_steps, 'converged': bool(converged)}
# Train model on security risk features (synthetic data).
# Seeded generator so the demo and its explanation are reproducible; the
# forest below is already seeded with random_state=42.
rng = np.random.default_rng(42)
feature_names = ['cvss_score', 'days_unpatched', 'network_exposure',
                 'host_type_server', 'n_open_ports', 'has_edr',
                 'internet_facing', 'admin_access_count']
n = 1000
X_risk = np.column_stack([
    rng.uniform(0, 10, n),        # cvss
    rng.integers(0, 365, n),      # days unpatched
    rng.uniform(0, 1, n),         # network exposure
    rng.binomial(1, 0.3, n),      # is server
    rng.integers(1, 50, n),       # open ports
    rng.binomial(1, 0.7, n),      # has EDR
    rng.binomial(1, 0.2, n),      # internet facing
    rng.integers(1, 20, n),       # admin access count
])
# Label: high risk (1) if cvss > 7 OR > 180 days unpatched OR exposure > 0.7.
y_risk = ((X_risk[:, 0] > 7) | (X_risk[:, 1] > 180) | (X_risk[:, 2] > 0.7)).astype(int)
scaler_r = StandardScaler()
X_r_s = scaler_r.fit_transform(X_risk)
rf_risk = RandomForestClassifier(n_estimators=100, random_state=42)
rf_risk.fit(X_r_s, y_risk)

# Explain a high-risk host (raw feature values, in feature_names order).
high_risk_host = np.array([8.5, 200, 0.85, 1, 35, 0, 1, 8])
x_s = scaler_r.transform(high_risk_host.reshape(1, -1)).ravel()
explainer = CounterfactualExplainer(rf_risk, scaler_r, feature_names)
result = explainer.explain(x_s, target_class=0)

# The search runs in standardized space; map the counterfactual back to raw
# units so the table shows real host values rather than z-scores.
x_cf_scaled = result['counterfactual']
x_cf_raw = scaler_r.inverse_transform(x_cf_scaled.reshape(1, -1)).ravel()
# Report what the model actually predicts instead of assuming it.
orig_pred = int(rf_risk.predict(x_s.reshape(1, -1))[0])
cf_pred = int(rf_risk.predict(x_cf_scaled.reshape(1, -1))[0])
risk_label = {0: "LOW RISK", 1: "HIGH RISK"}

print("Counterfactual Explanation for High-Risk Host:\n")
print(f"Original prediction: {risk_label[orig_pred]} (class {orig_pred})")
print(f"Target: LOW RISK (class 0)")
if cf_pred == 0:
    print(f"Steps to flip: {result['n_steps']}\n")
else:
    # find_counterfactual returns max_iter when it fails to flip the prediction.
    print(f"Did not flip within {result['n_steps']} steps "
          f"(counterfactual still {risk_label[cf_pred]})\n")
print(f"{'Feature':<25} {'Original':>10} {'Counterfactual':>16} {'Change':>8}")
print("-" * 63)
for feat, *_ in result['changes']:
    j = feature_names.index(feat)
    orig, cf = high_risk_host[j], x_cf_raw[j]
    print(f"{feat:<25} {orig:>10.3f} {cf:>16.3f} {cf - orig:>+8.3f}")
print("\n→ Remediation actions: patch (reduce days_unpatched), reduce exposure, add EDR")