Copy import numpy as np
class ReplayBuffer:
"""Experience replay: store and sample past transitions"""
def __init__(self, capacity: int = 10000):
self.buffer = []
self.capacity = capacity
self.position = 0
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size: int) -> list:
idx = np.random.choice(len(self.buffer), batch_size, replace=False)
return [self.buffer[i] for i in idx]
def __len__(self): return len(self.buffer)
class DQNNetwork:
"""Simple neural network for Q-value estimation"""
def __init__(self, state_dim: int, action_dim: int, hidden: int = 64):
np.random.seed(42)
self.W1 = np.random.randn(state_dim, hidden) * np.sqrt(2/state_dim)
self.b1 = np.zeros(hidden)
self.W2 = np.random.randn(hidden, hidden) * np.sqrt(2/hidden)
self.b2 = np.zeros(hidden)
self.W3 = np.random.randn(hidden, action_dim) * np.sqrt(2/hidden)
self.b3 = np.zeros(action_dim)
def forward(self, x: np.ndarray) -> np.ndarray:
h1 = np.maximum(0, x @ self.W1 + self.b1)
h2 = np.maximum(0, h1 @ self.W2 + self.b2)
return h2 @ self.W3 + self.b3
def update(self, states, actions, targets, lr: float = 0.001):
q_vals = self.forward(states)
errors = np.zeros_like(q_vals)
for i, (a, t) in enumerate(zip(actions, targets)):
errors[i, a] = q_vals[i, a] - t
# Backprop (simplified)
h1 = np.maximum(0, states @ self.W1 + self.b1)
h2 = np.maximum(0, h1 @ self.W2 + self.b2)
dW3 = h2.T @ errors / len(states)
dh2 = errors @ self.W3.T * (h2 > 0)
dW2 = h1.T @ dh2 / len(states)
dh1 = dh2 @ self.W2.T * (h1 > 0)
dW1 = states.T @ dh1 / len(states)
self.W3 -= lr * dW3; self.W2 -= lr * dW2; self.W1 -= lr * dW1
class DQNAgent:
"""DQN with experience replay and target network"""
def __init__(self, state_dim: int, action_dim: int,
gamma: float = 0.95, epsilon: float = 1.0,
batch_size: int = 64, target_update: int = 50):
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_min = 0.05
self.epsilon_decay= 0.995
self.batch_size = batch_size
self.target_update= target_update
self.step_count = 0
self.q_net = DQNNetwork(state_dim, action_dim)
self.target_net= DQNNetwork(state_dim, action_dim)
self._sync_target()
self.memory = ReplayBuffer()
def _sync_target(self):
"""Periodically copy online → target network (stabilises training)"""
for attr in ['W1','b1','W2','b2','W3','b3']:
setattr(self.target_net, attr, getattr(self.q_net, attr).copy())
def act(self, state: np.ndarray) -> int:
if np.random.random() < self.epsilon:
return np.random.randint(self.action_dim)
q_vals = self.q_net.forward(state.reshape(1, -1))[0]
return int(np.argmax(q_vals))
def learn(self):
if len(self.memory) < self.batch_size:
return
batch = self.memory.sample(self.batch_size)
states = np.array([b[0] for b in batch])
actions = [b[1] for b in batch]
rewards = np.array([b[2] for b in batch])
nexts = np.array([b[3] for b in batch])
dones = np.array([b[4] for b in batch])
# DQN target: r + γ * max_a' Q_target(s', a')
next_q = self.target_net.forward(nexts).max(1)
targets_arr = rewards + self.gamma * next_q * (1 - dones)
self.q_net.update(states, actions, targets_arr)
self.step_count += 1
if self.step_count % self.target_update == 0:
self._sync_target()
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
def push(self, state, action, reward, next_state, done):
self.memory.push(state, action, reward, next_state, done)
dqn_agent = DQNAgent(state_dim=5, action_dim=5, batch_size=32)
dqn_rewards = []
print("Training DQN agent (300 episodes)...")
for ep in range(300):
state = env.reset(); total = 0
for _ in range(env.MAX_STEPS):
action = dqn_agent.act(state)
ns, r, done = env.step(action)
dqn_agent.push(state, action, r, ns, done)
dqn_agent.learn()
state = ns; total += r
if done: break
dqn_rewards.append(total)
print(f"\nDQN vs Q-Learning comparison:")
for label, rews in [("Q-Learning", rewards), ("DQN", dqn_rewards)]:
early = np.mean(rews[:50])
late = np.mean(rews[-50:])
print(f" {label:<12}: early={early:>8.2f} late={late:>8.2f} improvement={late-early:>+8.2f}")