# threat_api.py — production ML inference service
# The inference service source, embedded as a string so the host script can
# introspect it (the loop below lists the @app routes without importing
# FastAPI). Fixes vs. the previous revision: UTC health timestamp, bounded
# batch size, thread-safe lazy model load, lazy log formatting.
API_CODE = '''
from fastapi import FastAPI, HTTPException, Depends, Header, Request
from pydantic import BaseModel, Field, validator
from typing import List, Optional
import numpy as np, pickle, time, hashlib, logging, threading
from collections import defaultdict

# --- Setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("threat_api")

app = FastAPI(title="Threat Detection API", version="1.0.0",
              description="Real-time ML-powered threat scoring")

# --- Load model ---
import sklearn.ensemble  # ensure the estimator class is importable for unpickling

MODEL_STORE = {}                # lazily populated: model, version, feature_names, trained_at
_MODEL_LOCK = threading.Lock()  # guards the one-time load

def get_model():
    """Return the model store, loading it from disk on first use (thread-safe).

    SECURITY: pickle.load executes arbitrary code from the file, and /tmp is
    world-writable -- any local user can swap the model for a payload. Move
    the artifact to a root-owned path and prefer a non-pickle format.
    """
    with _MODEL_LOCK:  # prevent two requests from racing the first load
        if "model" not in MODEL_STORE:
            with open("/tmp/model.pkl", "rb") as f:
                MODEL_STORE.update(pickle.load(f))
    return MODEL_STORE

# --- Request/Response schemas ---
class ThreatFeatures(BaseModel):
    """One event's normalised feature vector.

    Field declaration order matters: predict() serialises .dict() values in
    this order, which must match the model's training-time feature order.
    """
    feat_00: float = Field(..., ge=-10, le=10, description="Normalised feature 0")
    feat_01: float = Field(..., ge=-10, le=10)
    feat_02: float = Field(..., ge=-10, le=10)
    feat_03: float = Field(..., ge=-10, le=10)
    feat_04: float = Field(..., ge=-10, le=10)
    feat_05: float = Field(..., ge=-10, le=10)
    feat_06: float = Field(..., ge=-10, le=10)
    feat_07: float = Field(..., ge=-10, le=10)
    feat_08: float = Field(..., ge=-10, le=10)
    feat_09: float = Field(..., ge=-10, le=10)

    @validator("*", pre=True)
    def check_finite(cls, v):
        # Reject NaN/inf up front; pydantic turns the raised error into a 422.
        if not np.isfinite(v):
            raise ValueError("Feature must be finite")
        return v

class ThreatScore(BaseModel):
    """Scoring result for a single event."""
    threat_score: float  # P(threat), rounded to 4 dp
    is_threat: bool      # threat_score >= 0.5
    confidence: str      # HIGH / MEDIUM / LOW by distance from the 0.5 boundary
    model_version: str
    latency_ms: float

class BatchRequest(BaseModel):
    # Bounded so one request cannot exhaust memory/CPU; min_items=1 also
    # avoids predict_proba on a zero-length array.
    events: List[ThreatFeatures] = Field(..., min_items=1, max_items=1000)

# --- Auth middleware ---
# NOTE(review): keys are hard-coded in source; load from env/secret store and
# compare digests (hashlib is already imported) rather than plaintext.
VALID_API_KEYS = {"soc-prod-key-abc123", "siem-integration-xyz456"}

def verify_api_key(x_api_key: str = Header(...)):
    """Dependency: require a known X-API-Key header; 401 otherwise."""
    if x_api_key not in VALID_API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return x_api_key

# --- Rate limiting ---
# In-process sliding window: per-key request timestamps from the last minute.
# NOTE(review): per-worker only -- use a shared store (e.g. Redis) when
# running multiple workers.
request_counts = defaultdict(list)

def check_rate_limit(api_key: str = Depends(verify_api_key)):
    """Dependency: allow at most 1000 requests per key per minute; 429 otherwise."""
    now = time.time()
    counts = request_counts[api_key]
    counts[:] = [t for t in counts if now - t < 60]  # evict entries older than 60s
    if len(counts) >= 1000:
        raise HTTPException(status_code=429, detail="Rate limit: 1000 req/min")
    counts.append(now)
    return api_key

# --- Endpoints ---
@app.get("/health")
def health():
    """Liveness probe; also reports whether the model has been loaded."""
    return {"status": "ok", "model_loaded": "model" in MODEL_STORE,
            "version": "1.0.0",
            # gmtime() so the trailing "Z" is truthful (was local time before).
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())}

@app.post("/predict", response_model=ThreatScore)
def predict(features: ThreatFeatures, api_key: str = Depends(check_rate_limit)):
    """Score a single event and return probability, verdict, and latency."""
    t0 = time.perf_counter()
    store = get_model()
    # .dict() preserves field declaration order (= training feature order).
    x = np.array(list(features.dict().values())).reshape(1, -1)
    score = float(store["model"].predict_proba(x)[0, 1])
    latency = (time.perf_counter() - t0) * 1000
    # Lazy %-args: no string formatting unless INFO is actually emitted.
    logger.info("predict: score=%.4f latency=%.1fms key=%s...",
                score, latency, api_key[:8])
    return ThreatScore(
        threat_score=round(score, 4), is_threat=(score >= 0.5),
        confidence="HIGH" if abs(score - 0.5) > 0.3 else "MEDIUM" if abs(score - 0.5) > 0.1 else "LOW",
        model_version=store["version"], latency_ms=round(latency, 2),
    )

@app.post("/predict/batch")
def predict_batch(req: BatchRequest, api_key: str = Depends(check_rate_limit)):
    """Score a batch of events (1..1000) in one vectorised model call."""
    t0 = time.perf_counter()
    store = get_model()
    X = np.array([list(e.dict().values()) for e in req.events])
    scores = store["model"].predict_proba(X)[:, 1]
    return {"predictions": [{"threat_score": round(float(s), 4), "is_threat": bool(s >= 0.5)}
                            for s in scores],
            "n_threats": int((scores >= 0.5).sum()),
            "latency_ms": round((time.perf_counter() - t0) * 1000, 2)}

@app.get("/model/info")
def model_info(api_key: str = Depends(verify_api_key)):
    """Model metadata; auth required but not rate limited."""
    store = get_model()
    return {"version": store["version"], "features": store["feature_names"],
            "trained_at": store["trained_at"], "n_features": len(store["feature_names"])}
'''
# Summarize the embedded application: print each FastAPI route decorator so an
# operator can see the exposed endpoints without reading the whole string.
print("FastAPI application code defined.")
print("Endpoints:")  # was a pointless f-string with no placeholders (ruff F541)
for line in API_CODE.split('\n'):
    if '@app.' in line:
        print(f"  {line.strip()}")