We present an information-theoretic regularization technique for few-shot novel view synthesis based on neural implicit representation. The proposed approach minimizes potential reconstruction inconsistency caused by insufficient viewpoints by imposing an entropy constraint on the density along each ray. In addition, to alleviate potential degeneracy when all training images are acquired from nearly redundant viewpoints, we further incorporate a spatial smoothness constraint into the estimated images by restricting information gains from pairs of rays with slightly different viewpoints. The main idea of our algorithm is to make reconstructed scenes compact along individual rays and consistent across neighboring rays. The proposed regularizers can be plugged into most existing NeRF-based neural volume rendering techniques in a straightforward way. Despite its simplicity, our method consistently outperforms existing neural view synthesis methods by large margins on multiple standard benchmarks.
Source: InfoNeRF: Ray Entropy Minimization for Few-Shot Neural Volume Rendering (2021-12-31). The Python script below reimplements the method in roughly 100 lines.
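The script implements the first regularizer, the per-ray entropy term; the pairwise information-gain (KL) smoothness term from the abstract is not included. As a minimal sketch of how the two terms look, assuming per-bin opacities of shape [num_rays, nb_bins] (the function names here are illustrative, not part of the original code):

def ray_entropy(alpha, eps=1e-10):
    # Normalize opacities along each ray into a discrete distribution and
    # return its Shannon entropy; InfoNeRF minimizes this quantity per ray.
    p = alpha / (alpha.sum(dim=1, keepdim=True) + eps)
    return -(p * torch.log2(p + eps)).sum(dim=1)

def information_gain(p, q, eps=1e-10):
    # KL divergence between the normalized ray densities p of an observed ray
    # and q of a slightly rotated neighbor; penalizing it smooths nearby rays.
    return (p * torch.log2((p + eps) / (q + eps))).sum(dim=1)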
import os
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
@torch.no_grad()
def test(hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
    # Render one held-out view in chunks of image rows and save it as a PNG.
    # Relies on the globals `model` and `device` defined in __main__.
    os.makedirs('novel_views', exist_ok=True)  # ensure the output folder exists
    ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
    ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]
    data = []
for i in range(int(np.ceil(H / chunk_size))):
ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
regenerated_px_values, _ = render_rays(model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
data.append(regenerated_px_values)
img = torch.cat(data).data.cpu().numpy().reshape(H, W, 3)
plt.figure()
plt.imshow(img)
plt.savefig(f'novel_views/img_{img_index}.png', bbox_inches='tight')
plt.close()
class NerfModel(nn.Module):
def __init__(self, embedding_dim_pos=20, embedding_dim_direction=8, hidden_dim=128):
super(NerfModel, self).__init__()
        # Density branch with a skip connection on the position embedding;
        # block2's extra output unit is the volume density sigma.
        self.block1 = nn.Sequential(nn.Linear(embedding_dim_pos * 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        self.block2 = nn.Sequential(nn.Linear(embedding_dim_pos * 3 + hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim + 1), )
        # Color branch: conditions on the view-direction embedding and maps to RGB.
        self.block3 = nn.Sequential(nn.Linear(embedding_dim_direction * 3 + hidden_dim, hidden_dim // 2), nn.ReLU(), )
        self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(), )
self.embedding_dim_pos = embedding_dim_pos
self.embedding_dim_direction = embedding_dim_direction
self.relu = nn.ReLU()
@staticmethod
def positional_encoding(x, L):
out = torch.empty(x.shape[0], x.shape[1] * 2 * L, device=x.device)
for i in range(x.shape[1]):
for j in range(L):
out[:, i * (2 * L) + 2 * j] = torch.sin(2 ** j * x[:, i])
out[:, i * (2 * L) + 2 * j + 1] = torch.cos(2 ** j * x[:, i])
return out
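    # Note (illustrative, not in the original script): the double loop above is
    # equivalent to this vectorized form, which preserves the interleaved
    # sin/cos layout in a single pass:
    #   freqs = 2 ** torch.arange(L, device=x.device)                 # [L]
    #   angles = x.unsqueeze(-1) * freqs                              # [N, 3, L]
    #   out = torch.stack((angles.sin(), angles.cos()), -1).reshape(x.shape[0], -1)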
def forward(self, o, d):
emb_x = self.positional_encoding(o, self.embedding_dim_pos // 2)
emb_d = self.positional_encoding(d, self.embedding_dim_direction // 2)
h = self.block1(emb_x)
tmp = self.block2(torch.cat((h, emb_x), dim=1))
h, sigma = tmp[:, :-1], self.relu(tmp[:, -1])
h = self.block3(torch.cat((h, emb_d), dim=1))
c = self.block4(h)
return c, sigma
def compute_accumulated_transmittance(alphas):
accumulated_transmittance = torch.cumprod(alphas, 1)
return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
accumulated_transmittance[:, :-1]), dim=-1)
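# The cumulative product above yields T_i = prod_{j<i} (1 - alpha_j), the probability
# that a ray travels to bin i without being absorbed; prepending a column of ones
# shifts the product so that T_0 = 1, giving the NeRF quadrature weights w_i = T_i * alpha_i.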
def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192, T=0.1):
device = ray_origins.device
t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
# Perturb sampling along each ray.
mid = (t[:, :-1] + t[:, 1:]) / 2.
lower = torch.cat((t[:, :1], mid), -1)
upper = torch.cat((mid, t[:, -1:]), -1)
u = torch.rand(t.shape, device=device)
t = lower + (upper - lower) * u # [batch_size, nb_bins]
delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)
x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1) # [batch_size, nb_bins, 3]
ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)
colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
colors = colors.reshape(x.shape)
sigma = sigma.reshape(x.shape[:-1])
    alpha = 1 - torch.exp(-sigma * delta)  # per-bin opacity, [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    # InfoNeRF ray-entropy regularization: normalize the opacities along each ray
    # into a discrete distribution and compute its Shannon entropy.
    prob = alpha / (alpha.sum(1).unsqueeze(1) + 1e-10)
    # Skip rays whose total opacity is below T (they likely hit only empty space).
    mask = alpha.sum(1).unsqueeze(1) > T
    regularization = -1 * prob * torch.log2(prob + 1e-10)
    regularization = (regularization * mask).sum(1).mean()
    c = (weights * colors).sum(dim=1)  # pixel colors from the volume rendering quadrature
    # Composite any leftover transmittance onto a white background.
    weight_sum = weights.sum(-1).sum(-1)
    return c + 1 - weight_sum.unsqueeze(-1), regularization
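# Worked example of the entropy term (illustrative, not from the original code):
# a ray whose opacity mass falls in a single bin has entropy near 0, while mass
# spread uniformly over all 192 bins has entropy log2(192) ≈ 7.58 bits. Minimizing
# the masked mean entropy therefore concentrates each ray's density into a compact
# surface hit, while the mask exempts rays that see only empty space.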
def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5),
nb_bins=192, lambda_mul=4, T=0.1):
training_loss = []
for _ in range(nb_epochs):
for batch in tqdm(data_loader):
ray_origins = batch[:, :3].to(device)
ray_directions = batch[:, 3:6].to(device)
ground_truth_px_values = batch[:, 6:].to(device)
regenerated_px_values, regularization = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf,
nb_bins=nb_bins, T=T)
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()
            training_loss.append(loss.item())  # log the reconstruction loss only
            loss = loss + lambda_mul * regularization  # add the weighted ray-entropy term
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
return training_loss
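# Optional post-hoc metric (not in the original script): the logged values are
# summed squared errors over a batch of 1024 rays with 3 channels each, so a
# rough per-batch PSNR curve could be computed as:
#   psnr = -10 * np.log10(np.array(training_loss) / (1024 * 3))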
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    training_dataset = torch.from_numpy(np.load('training_data.pkl', allow_pickle=True))
    testing_dataset = torch.from_numpy(np.load('testing_data.pkl', allow_pickle=True))
    # Only keep four of the training images (the few-shot setting).
    training_dataset = torch.cat((training_dataset[0 * 400 * 400:(0 + 1) * 400 * 400],
                                  training_dataset[30 * 400 * 400:(30 + 1) * 400 * 400],
                                  training_dataset[60 * 400 * 400:(60 + 1) * 400 * 400],
                                  training_dataset[90 * 400 * 400:(90 + 1) * 400 * 400]))
model = NerfModel(hidden_dim=256).to(device)
model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)
data_loader = DataLoader(training_dataset, batch_size=1024, shuffle=True)
train(model, model_optimizer, scheduler, data_loader, nb_epochs=32, device=device, hn=2, hf=6, nb_bins=192, T=0.1)
for img_index in range(200):
test(2, 6, testing_dataset, img_index=img_index, nb_bins=192, H=400, W=400)