Experiment 1#



Exp1#

The first experiment attempts to replicate one of Anthropic's research outcomes (Elhage et al., 2022) (see the left square box in the diagram below): when feature representations of a toy model are compressed from a larger dimension into a smaller one (2 dim) with no sparsity (sparsity=0), the model comes to represent only two features, drawing them approximately orthogonally (a short sketch of what "sparsity" means here follows the list below).

Here,

  • the toy model is a one-layer linear model with a ReLU activation.

  • an autoencoder compresses the toy model's hidden representation from a larger dimension (49 dim) into a smaller one (2 dim).

  • the toy model is trained on the MNIST hand-written digit image (28x28) dataset.

  • the question is whether the 49 assumed features converge into 2 feature representations.
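
For reference, "sparsity" here follows Elhage et al.'s definition: each input feature is independently set to zero with probability S, so sparsity=0 means every feature is always active. A minimal sketch of that sampling scheme (the helper sample_features is hypothetical, not part of this notebook):

### sketch: what "sparsity" means in the toy-model setup (illustrative)
import torch

def sample_features(batch, n_features, sparsity):
    x = torch.rand(batch, n_features)                  # feature values in [0, 1)
    mask = torch.rand(batch, n_features) >= sparsity   # keep each with prob. 1 - S
    return x * mask

dense = sample_features(8, 49, sparsity=0.0)   # no entries forced to zero
sparse = sample_features(8, 49, sparsity=0.9)  # ~90% of entries zeroed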


[Figure: toy-model superposition diagram from Elhage et al. (2022)]



Exp1.0. Prepare packages#



### import packages 

# ml/nn
import torch
import torch.nn as nn  # all neural network modules
import torch.nn.functional as F  # Functions with no parameters -> activation functions
import torch.optim as optim  # optimization algo

from torch.utils.data import DataLoader # easier dataset management, helps create mini batches
from torch.utils.data import random_split # set train-test ratio

# torchvision
import torchvision.datasets as datasets  # standard datasets
import torchvision.transforms as transforms # convert datasets to tensors
from torchvision.utils import make_grid # grid visualization of images

# stats/ml #1
import numpy as np
import matplotlib.pyplot as plt

# stats/ml #2
from sklearn import metrics
from sklearn import decomposition
from sklearn import manifold
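
Since the training plots below vary from run to run, it may help to fix the random seeds first. This cell is an optional addition, not part of the original run:

### optional: fix random seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)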



Exp1.1. Run first exp#



### set device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
cpu
### Define all parameters

# later these will be inserted at the beginning as args

# fixed
batch_size = 64
input_size = 784
num_classes = 10
hidden_size2 = 2

# variable
hidden_size = 49
lr1 = 0.001  # lr for opt1 
lr2 = 0.0001 # lr for opt2 # 0.001, or etc 
num_epochs1 = 10
num_epochs2 = 20
### define original functions

def norm_point(x, y):
    # Calculate the length of the vector
    length = np.sqrt(x**2 + y**2)
    
    # Check where length is 0
    zero_mask = (length == 0)  # Boolean mask for elements where length is 0

    # Initialize normalized coordinates
    x_norm = np.zeros_like(x)
    y_norm = np.zeros_like(y)
    
    # Perform normalization only where length != 0
    x_norm[~zero_mask] = x[~zero_mask] / length[~zero_mask]
    y_norm[~zero_mask] = y[~zero_mask] / length[~zero_mask]
    
    # For dimensions where length == 0, retain the original values
    x_norm[zero_mask] = x[zero_mask]
    y_norm[zero_mask] = y[zero_mask]

    return x_norm, y_norm
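
A quick sanity check of norm_point (illustrative; assumes the function and the numpy import above): zero-length vectors pass through unchanged, everything else lands on the unit circle.

### quick check of norm_point (illustrative)
x = np.array([3.0, 0.0, 0.0])
y = np.array([4.0, 2.0, 0.0])
xn, yn = norm_point(x, y)
print(xn, yn)                  # [0.6 0.  0. ] [0.8 1.  0. ]
print(np.sqrt(xn**2 + yn**2))  # [1. 1. 0.]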
### implement two single FC NN models

class lin_model(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(lin_model, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) # input: 28x28=784, hidden: 7x7=49
        self.fc2 = nn.Linear(hidden_size, num_classes) # hidden: 49, num_classes: 10
    def forward(self, x):
        x = self.fc1(x) # 784->49
        x = torch.tanh(x) # squash into [-1, 1]
        x_h = F.relu(x) # zero out negatives, mimicking a ReLU output model
        x = self.fc2(x_h) # 49->10
        return x, x_h

class ae_model(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes): # num_classes is unused; kept for a uniform signature
        super(ae_model, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) # encoder: 49 -> 2, for visualising superposition in 2d
        self.fc2 = nn.Linear(hidden_size, input_size) # decoder: 2 -> 49
    def forward(self, x):
        x = self.fc1(x) # 49->2
        x_2d = torch.tanh(x) # squash into [-1, 1]
        x = self.fc2(x_2d) # 2->49
        x = torch.tanh(x) # squash into [-1, 1]
        x = F.relu(x) # simulate lin_model's output range
        return x, x_2d
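
As a shape sanity check (illustrative, assuming the two classes above), a fake batch traced through both models:

### shape check (illustrative)
dummy = torch.randn(8, 784)            # a fake batch of 8 flattened images
out, h = lin_model(784, 49, 10)(dummy)
print(out.shape, h.shape)              # torch.Size([8, 10]) torch.Size([8, 49])
recon, h2d = ae_model(49, 2, 10)(h)
print(recon.shape, h2d.shape)          # torch.Size([8, 49]) torch.Size([8, 2])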
### define datasets

# load dataset
train_dataset = datasets.MNIST(root='../dataset/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='../dataset/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
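
Each batch arrives as [batch, 1, 28, 28], which is why the training loops below reshape it to [batch, 784]. A quick peek (illustrative):

### peek at one batch (illustrative)
data, targets = next(iter(train_loader))
print(data.shape, targets.shape)      # torch.Size([64, 1, 28, 28]) torch.Size([64])
print(data.reshape(-1, 28*28).shape)  # torch.Size([64, 784])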
### initialize model, loss, optimizer

m1 = lin_model(
    input_size=input_size, 
    hidden_size=hidden_size, 
    num_classes=num_classes
).to(device)

m2 = ae_model(
    input_size=hidden_size, 
    hidden_size=hidden_size2, 
    num_classes=num_classes
).to(device)

# Loss and Optimizer
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.MSELoss()

opt1 = optim.Adam(m1.parameters(), lr=lr1) 
opt2 = optim.Adam(m2.parameters(), lr=lr2)
### train and validation

num_epochs1 = 10
num_epochs2 = 20

###
### phase 1: prepare trained weights with larger dim
###

# training lin model
for epoch in range(num_epochs1):

    # Training
    for batch_idx, (data, targets) in enumerate(train_loader):

        # reshape [batch, 1, 28,28] to [batch, 28*28]
        data = data.reshape(-1, 28*28)

        # Step 1: Forward pass through model1
        out1, h1 = m1(data) # out1:[batch, 10] as 10 class, h1:[batch, 49] as 49 latent
        
        # Step 2: Compute loss1 and update model1
        loss1 = criterion1(out1, targets)
        opt1.zero_grad()
        loss1.backward()
        opt1.step()

    num_correct = 0
    num_samples = 0
    
    # Validation
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(test_loader):
            data = data.reshape(-1, 28*28)
            out1, _ = m1(data)
            _, predictions = out1.max(1)
            num_correct += (predictions == targets).sum()
            num_samples += predictions.size(0)

        accuracy = num_correct/num_samples
    
    print(f"Epoch [{epoch+1}/{num_epochs1}], loss1: {loss1.item():.4f}, accuracy: {accuracy.item():.4f}")

###
### phase 2: train AE model weights with 2 dim
###

# training ae model and visualize results
for epoch in range(num_epochs2):

    # Training autoencoder model
    for batch_idx, (data, targets) in enumerate(train_loader):

        # reshape [batch, 1, 28,28] to [batch, 28*28]
        data = data.reshape(-1, 28*28)

        with torch.no_grad():
            _, h1 = m1(data) # out1:[batch, 10] as 10 class, h1:[batch, 49] as 49 latent
            h1_clone = h1.detach()

        out2, _ = m2(h1_clone)  # 49 -> 2 -> 49

        loss2 = criterion2(out2, h1_clone)
        opt2.zero_grad()
        loss2.backward()
        opt2.step()

    # visualization
    with torch.no_grad():

        # Collecting data # lin model
        for batch_idx, (data, targets) in enumerate(train_loader):

            data = data.reshape(-1, 28*28)
            _, h1 = m1(data)
            
            if batch_idx == 0:
                target_h1 = h1
            else:
                target_h1 = torch.vstack((target_h1,h1))

        # Collecting data # ae model
        reconstruct_target_h1, h2_2d = m2(target_h1)
        
        # Collecting data # original loss
        target_loss2_orig = criterion2(reconstruct_target_h1, target_h1)

        # Importance and xy2d coord calculation
        n_samples, large_model_dim = target_h1.shape # note: the full training set is stacked here, not one mini-batch
        importance = []
        target_dim_2d = []

        for one_dim in range(large_model_dim):

            # importance
            h1_perturb = target_h1.clone()
            h1_perturb[:,one_dim] = 0.0
            reconstruct_h1_perturb, _ = m2(h1_perturb)
            loss2_perturb = criterion2(reconstruct_h1_perturb, target_h1)
            importance.append(loss2_perturb - target_loss2_orig)

            # xy_2d
            target_h1_one_dim = torch.zeros_like(target_h1, dtype=torch.float) # batch, larger dim
            target_h1_one_dim[:,one_dim] = target_h1[:,one_dim]
            _, target_h1_2d = m2(target_h1_one_dim)
            target_dim_2d.append(target_h1_2d.mean(axis=0))
        
        target_dim_2d = np.array(target_dim_2d)
        
        # express each dim's importance as its share of the total MSE increase
        importance = np.array([imp/sum(importance) for imp in importance])
        # importance[importance<(1/large_model_dim)] = 0.0 # optionally zero out below-average contributions
        importance_norm = (importance - np.min(importance)) / (np.max(importance) - np.min(importance)) # min-max normalize into the 0-1 range

        # compute xy coordinates

        # left panel: use the raw xy2d vector with unrestricted magnitude
        xy_coord = importance_norm[:, np.newaxis] * target_dim_2d # broadcast [large_dim,1] * [large_dim,2]
        # xy_coord = xy_coord/np.abs(xy_coord).max() # optionally rescale the max magnitude into [-1, 1]
        x_coord = xy_coord[:,0]
        y_coord = xy_coord[:,1]

        # right panel: normalize each xy2d direction to magnitude 1
        x_coord_norm, y_coord_norm = norm_point(target_dim_2d[:,0], target_dim_2d[:,1])
        x_coord_norm = x_coord_norm * importance_norm
        y_coord_norm = y_coord_norm * importance_norm

        # graph visualization
        fig, ax = plt.subplots(1,2, figsize=(7,3))
        for i in range(len(x_coord)):
            ax[0].scatter(x_coord[i], y_coord[i])
            ax[0].plot([0,x_coord[i]],[0,y_coord[i]])
            ax[1].scatter(x_coord_norm[i], y_coord_norm[i])
            ax[1].plot([0,x_coord_norm[i]],[0,y_coord_norm[i]])

        ax[0].title.set_text("imp * len=flex")
        ax[0].axhline(0, color='gray', linestyle='--', linewidth=0.5)
        ax[0].axvline(0, color='gray', linestyle='--', linewidth=0.5)
        ax[1].title.set_text("imp * len=fix1")
        ax[1].axhline(0, color='gray', linestyle='--', linewidth=0.5)
        ax[1].axvline(0, color='gray', linestyle='--', linewidth=0.5)
        plt.tight_layout()
        plt.show()

    print(f"Epoch [{epoch+1}/{num_epochs2}], loss2: {loss2.item():.4f}")
Epoch [1/10], loss1: 0.1395, accuracy: 0.9274
Epoch [2/10], loss1: 0.1529, accuracy: 0.9430
Epoch [3/10], loss1: 0.3065, accuracy: 0.9490
Epoch [4/10], loss1: 0.1372, accuracy: 0.9533
Epoch [5/10], loss1: 0.1572, accuracy: 0.9571
Epoch [6/10], loss1: 0.0578, accuracy: 0.9605
Epoch [7/10], loss1: 0.1927, accuracy: 0.9600
Epoch [8/10], loss1: 0.0783, accuracy: 0.9628
Epoch [9/10], loss1: 0.1768, accuracy: 0.9640
Epoch [10/10], loss1: 0.1349, accuracy: 0.9653
[Figures: one pair of 2-D feature plots per epoch, "imp * len=flex" (left) and "imp * len=fix1" (right)]
Epoch [1/20], loss2: 0.2728
Epoch [2/20], loss2: 0.2346
Epoch [3/20], loss2: 0.2213
Epoch [4/20], loss2: 0.2009
Epoch [5/20], loss2: 0.1969
Epoch [6/20], loss2: 0.1947
Epoch [7/20], loss2: 0.2021
Epoch [8/20], loss2: 0.1921
Epoch [9/20], loss2: 0.1978
Epoch [10/20], loss2: 0.1860
Epoch [11/20], loss2: 0.1952
Epoch [12/20], loss2: 0.1888
Epoch [13/20], loss2: 0.1852
Epoch [14/20], loss2: 0.1772
Epoch [15/20], loss2: 0.1638
Epoch [16/20], loss2: 0.1621
Epoch [17/20], loss2: 0.1632
Epoch [18/20], loss2: 0.1743
Epoch [19/20], loss2: 0.1532
Epoch [20/20], loss2: 0.1584



Exp1.2. Result#

  • m1 = the one-layer linear model that outputs a 49-dim hidden layer while being trained on the 10-class hand-written digit images (28x28 = 784 dim).

  • m2 = the AE model trained on the larger hidden layer (49 dim) to represent each feature in a smaller hidden layer (2 dim).

  • the left panels plot each of the 49 features as the m2 output (2-dim xy coordinates) multiplied by its importance.

  • the right panels plot each of the 49 features as the m2 output direction (2-dim xy coordinates), rescaled to a fixed magnitude of 1 and multiplied by its importance.

  • the 49 features aggregate into roughly 2 features over the 20 training epochs.

  • in the middle epochs the representations look scattered, but they gradually converge onto roughly two directions in the 2-dim space (a quick way to quantify this is sketched below).
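
One rough, illustrative way to quantify that convergence, assuming x_coord and y_coord from the final training epoch are still in scope: count the features that retain meaningful magnitude, and measure the angle between the two largest.

### rough check: how many features survive, and at what angle? (illustrative)
mag = np.sqrt(x_coord**2 + y_coord**2)
print((mag > 0.5 * mag.max()).sum(), "features above half the max magnitude")

top2 = np.argsort(mag)[::-1][:2]               # the two largest-magnitude features
v1 = np.array([x_coord[top2[0]], y_coord[top2[0]]])
v2 = np.array([x_coord[top2[1]], y_coord[top2[1]]])
cos = np.clip(v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2)), -1.0, 1.0)
print(f"angle between the top-2 features: {np.degrees(np.arccos(cos)):.1f} deg")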



Exp1.3. Run second exp#

The second experiment attempts to replicate another of Anthropic's research outcomes (Elhage et al., 2022) (see the first column of the ReLU Output Model in the diagram below): testing whether the top important features (2 dim) are clearly represented in the toy model compared to the remaining, least important features (47 dim), when sparsity is not set (sparsity=0).

Here,

  • \(W\) in \(h = Wx\) is the AE encoding from \(x\) (49 dim) to \(h\) (2 dim).

  • \(b\) is the bias in the reconstruction below (the decoder bias in the code).

  • \(W^T\) in \(x' = \mathrm{ReLU}(W^T W x + b)\) reconstructs \(x'\) (49 dim) from \(h\) (2 dim).

  • \(W^T W\) is the matrix of dot-product similarities between features (a 49x49 matrix).

  • \(\|W_i\|\) tests whether each feature is clearly represented.

  • \(\sum_{j \neq i} (\hat{W}_i \cdot W_j)^2\) calculates the similarity of a target feature direction, \(\hat{W}_i\), with the remaining features (a tiny numeric sketch of these quantities follows the diagram).



[Figure: ReLU Output Model diagram from Elhage et al. (2022)]
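
To ground these formulas, here is a tiny numeric sketch with a made-up 2x3 \(W\) (three features embedded in two dimensions; purely illustrative, not the trained weights):

### numeric sketch of ||Wi||, W^T W, and interference (illustrative)
W = np.array([[1.0, 0.0, 0.6],
              [0.0, 1.0, 0.8]])       # columns = three features in 2-d

wtw = W.T @ W                          # 3x3 matrix of dot products
norms = np.linalg.norm(W, axis=0)      # ||W_i|| for each feature
print(np.round(wtw, 2), np.round(norms, 2))

# interference of feature i: sum over j != i of (W_i/||W_i|| . W_j)^2
W_hat = W / norms
interference = (W_hat.T @ W)**2
interference[np.arange(3), np.arange(3)] = 0.0
print(np.round(interference.sum(axis=1), 2))   # approx. [0.36 0.64 1.0]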



### demonstrate wT*w, W-norm, cos sim distance

with torch.no_grad():
    w1 = m2.fc1.weight.detach().numpy()
    w2 = m2.fc2.weight.detach().numpy()
    b1 = m2.fc1.bias.detach().numpy()
    b2 = m2.fc2.bias.detach().numpy()
    
    # Sort features by importance (descending)
    sorted_indices = np.argsort(importance_norm)[::-1] # indices that order importance high -> low
    w1_sorted = w1[:, sorted_indices] # reorder encoder columns by importance
    w2_sorted = w2[sorted_indices, :] # reorder decoder rows by importance
    b2_sorted = b2[sorted_indices][::-1] # reorder the decoder bias (the extra [::-1] flips it back to ascending importance)

    w1_sorted_t = w1_sorted.T
    mask_range = np.array(range(len(w1_sorted_t)))
    w1_cos_sim = []
    w2_cos_sim = []

    for i in range(len(w1_sorted_t)):
        # sum of squared dot products of feature i with all other features
        # (unnormalized, so "cos sim" is really an interference measure)
        w1_sorted_t_target = w1_sorted_t[i]
        idx_mask = (mask_range == i)
        w1_sorted_t_ref = w1_sorted_t[~idx_mask]
        cos_sim1 = ((w1_sorted_t_target @ w1_sorted_t_ref.T)**2).sum()
        w1_cos_sim.append(cos_sim1)

        # same measure for the decoder weights
        w2_sorted_target = w2_sorted[i]
        w2_sorted_ref = w2_sorted[~idx_mask]
        cos_sim2 = ((w2_sorted_ref @ w2_sorted_target)**2).sum()
        w2_cos_sim.append(cos_sim2)

    w1_cos_dis = np.array(w1_cos_sim)
    w2_cos_dis = np.array(w2_cos_sim)

    x = w1_sorted.T[:,0]
    y = w1_sorted.T[:,1]
    w_norm = np.sqrt(x**2 + y**2)

    wtw = (w2_sorted @ w1_sorted)
    # wtw = (w_sorted.T @ w_sorted)

    wtw_clone = wtw.copy()
    thres = np.percentile(wtw_clone, [99])
    wtw_clone[wtw_clone<thres] = 0.0

    fig, ax = plt.subplots(1,5, figsize=((13,3)))

    ax[0].plot(w_norm) # ||Wi||
    ax[1].plot(b2_sorted) # b    
    ax[2].plot(w1_cos_dis) # cos sim
    pos1 = ax[3].imshow(wtw)
    pos2 = ax[4].imshow(wtw_clone)

    ax[0].title.set_text("$||W_i||$")
    ax[1].title.set_text("b")
    ax[2].title.set_text("cos sim")
    ax[3].title.set_text("$W^TW$")
    ax[4].title.set_text("$W^TW$, top1%")

    fig.colorbar(pos1)
    fig.colorbar(pos2)

    plt.tight_layout()
    # plt.savefig("img.png")
    plt.show()
[Figure: five panels showing ||W_i||, b, cos sim, W^TW, and W^TW (top 1%)]



Exp1.4. Result#

  • the x axis of every graph here is the feature importance rank, 0th the highest and 48th the lowest.

  • \(\|W_i\|\) is higher where the feature's importance is higher.

  • b becomes larger as importance decreases.

  • cos sim: everything sits around 0 (note that the code sums squared raw dot products; a properly normalized check is sketched below).

  • \(W^TW\) shows higher similarity where the feature importance is higher.

  • Unfortunately, the model hyperparameters were not set up properly, so the outputs of \(\|W_i\|\), b, and cos sim come out in slightly different units (see the y-axis of each) than Anthropic's outcome.
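
Since the "cos sim" above sums squared raw dot products, a properly normalized cosine similarity is worth checking as well (illustrative, assuming w1_sorted from the cell above):

### normalized cosine similarity between encoder features (illustrative)
feats = w1_sorted.T                            # (49, 2) feature vectors
norms = np.linalg.norm(feats, axis=1, keepdims=True)
unit = feats / np.clip(norms, 1e-8, None)      # guard against zero-length rows
cos = unit @ unit.T                            # true cosine similarities
np.fill_diagonal(cos, 0.0)                     # ignore self-similarity
n = len(cos)
print(f"mean |cos| between distinct features: {np.abs(cos).sum() / (n*n - n):.3f}")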



pos = plt.imshow(wtw_clone[:5,:5])
plt.colorbar(pos)
plt.title("$W^TW$, 5x5, top1%")
plt.show()
[Figure: top-left 5x5 block of W^TW, top 1% thresholded]