We introduce Plenoxels (plenoptic voxels), a system for photorealistic view synthesis. Plenoxels represent a scene as a sparse 3D grid with spherical harmonics. This representation can be optimized from calibrated images via gradient methods and regularization, without any neural components. On standard benchmark tasks, Plenoxels are optimized two orders of magnitude faster than Neural Radiance Fields with no loss in visual quality.
Source: Plenoxels: Radiance Fields without Neural Networks (2021-12-09). Paper: arXiv:2112.05131.
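Concretely, each voxel stores one density value plus 27 spherical-harmonic coefficients (9 per RGB channel), i.e. 28 learnable values per voxel. A minimal sketch of that layout (the variable names here are illustrative, not from the original code):

import torch

voxel = torch.zeros(28)              # one voxel: [density, 27 SH coefficients]
density = voxel[0]                   # opacity (sigma), kept non-negative via ReLU
sh_coeffs = voxel[1:].reshape(3, 9)  # degree-2 SH coefficients, one set per RGB channel

The complete implementation, in about 100 lines of Python, follows.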
import os

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
def eval_spherical_function(k, d):
    """Evaluate real spherical harmonics (degree <= 2) with coefficients k in directions d."""
    x, y, z = d[..., 0:1], d[..., 1:2], d[..., 2:3]
    # Hardcoded SH polynomials, adapted from
    # https://github.com/google/spherical-harmonics/blob/master/sh/spherical_harmonics.cc
    return (
        # Degree 0: constant term
        0.282095 * k[..., 0]
        # Degree 1: linear in the view direction
        - 0.488603 * y * k[..., 1] + 0.488603 * z * k[..., 2] - 0.488603 * x * k[..., 3]
        # Degree 2: quadratic in the view direction
        + 1.092548 * x * y * k[..., 4] - 1.092548 * y * z * k[..., 5]
        + 0.315392 * (2.0 * z * z - x * x - y * y) * k[..., 6]
        - 1.092548 * x * z * k[..., 7] + 0.546274 * (x * x - y * y) * k[..., 8]
    )
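# Example: if only the DC coefficient is set (k[..., 0] = 1, all others 0), the
# function returns ~0.282095 for every unit direction d, i.e. a view-independent
# color; the degree-1 and degree-2 terms are what add view dependence.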
class NerfModel(nn.Module):
    """Plenoxel-style model: a dense voxel grid of densities and SH coefficients, with no neural network."""
    def __init__(self, N=256, scale=1.5):
        """
        :param N: Number of voxels along each axis of the grid
        :param scale: The maximum absolute value among all coordinates for objects in the scene
        """
        super(NerfModel, self).__init__()
        # Each voxel stores 1 density value (channel 0) plus 27 SH coefficients (9 per RGB channel).
        self.voxel_grid = nn.Parameter(torch.ones((N, N, N, 27 + 1)) / 100)
        self.scale = scale
        self.N = N
    def forward(self, x, d):
        color = torch.zeros_like(x)
        sigma = torch.zeros((x.shape[0]), device=x.device)
        # Only query points inside the grid's bounding box; everything else stays empty.
        mask = (x[:, 0].abs() < self.scale) & (x[:, 1].abs() < self.scale) & (x[:, 2].abs() < self.scale)
        # Nearest-neighbor voxel lookup (the paper interpolates trilinearly; see the sketch below).
        idx = (x[mask] / (2 * self.scale / self.N) + self.N / 2).long().clip(0, self.N - 1)
        tmp = self.voxel_grid[idx[:, 0], idx[:, 1], idx[:, 2]]
        sigma[mask], k = torch.nn.functional.relu(tmp[:, 0]), tmp[:, 1:]
        # View-dependent color from the degree-2 spherical-harmonic coefficients.
        color[mask] = eval_spherical_function(k.reshape(-1, 3, 9), d[mask])
        return color, sigma
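# Hedged sketch (not part of the original script): the paper interpolates voxel
# values trilinearly, whereas forward() above does a nearest-neighbor lookup.
# Trilinear interpolation can be written with torch.nn.functional.grid_sample,
# whose grid coordinates are normalized to [-1, 1] and ordered innermost-first.
def trilinear_lookup(voxel_grid, x, scale):
    # voxel_grid: [N, N, N, C] parameters; x: [M, 3] points inside [-scale, scale]^3.
    grid = voxel_grid.permute(3, 0, 1, 2)[None]         # [1, C, N, N, N]
    coords = (x / scale).flip(-1)[None, :, None, None]  # [1, M, 1, 1, 3], (z, y, x) order
    out = torch.nn.functional.grid_sample(grid, coords, align_corners=True)
    return out[0, :, :, 0, 0].T                         # [M, C] interpolated features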
@torch.no_grad()
def test(hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
    """Render test image img_index in chunks of rows and save it to Imgs/ (uses the globals model and device)."""
    ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
    ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]
    data = []
    for i in range(int(np.ceil(H / chunk_size))):
        # Render chunk_size rows of pixels at a time to bound memory usage.
        ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
        ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
        regenerated_px_values = render_rays(model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values)
    img = torch.cat(data).data.cpu().numpy().reshape(H, W, 3).clip(0, 1)

    os.makedirs('Imgs', exist_ok=True)
    plt.figure()
    plt.imshow(img)
    plt.savefig(f'Imgs/img_{img_index}.png', bbox_inches='tight')
    plt.close()
def compute_accumulated_transmittance(alphas):
    """Given (1 - alpha) per sample, return T_i = prod_{j<i} (1 - alpha_j), with T_0 = 1 for the first sample."""
    accumulated_transmittance = torch.cumprod(alphas, 1)
    return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
                      accumulated_transmittance[:, :-1]), dim=-1)
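# Worked example: compute_accumulated_transmittance(torch.tensor([[0.9, 0.5, 0.5]]))
# returns tensor([[1.0000, 0.9000, 0.4500]]): each sample sees the product of all
# earlier (1 - alpha) factors, i.e. the fraction of light that survives to reach it.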
def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    """Volume-render a batch of rays between the near plane hn and the far plane hf."""
    device = ray_origins.device
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
    # Stratified sampling: perturb each sample within its bin along the ray.
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)
    # 3D position of each sample along its ray: o + t * d.
    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [batch_size, nb_bins, 3]
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)
    sigma = sigma.reshape(x.shape[:-1])

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    c = (weights * colors).sum(dim=1)  # Pixel colors
    # Composite onto a white background: leftover transmittance contributes white.
    weight_sum = weights.sum(-1).sum(-1)
    return c + 1 - weight_sum.unsqueeze(-1)
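# The paper additionally applies total variation (TV) regularization to the voxel
# grid, which the training loop below omits. A hedged sketch of such a term (a
# squared-difference variant, not necessarily the paper's exact formulation):
def tv_regularization(voxel_grid):
    # Penalize differences between neighboring voxels along each spatial axis.
    dx = (voxel_grid[1:, :, :] - voxel_grid[:-1, :, :]).pow(2).mean()
    dy = (voxel_grid[:, 1:, :] - voxel_grid[:, :-1, :]).pow(2).mean()
    dz = (voxel_grid[:, :, 1:] - voxel_grid[:, :, :-1]).pow(2).mean()
    return dx + dy + dz  # usage: loss = mse + lambda_tv * tv_regularization(model.voxel_grid)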
def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5),
          nb_bins=192):
    """Optimize the voxel grid by minimizing the MSE between rendered and ground-truth pixel colors."""
    training_loss = []
    for _ in range(nb_epochs):
        for batch in tqdm(data_loader):
            # Each row holds one pixel's ray origin, ray direction, and RGB value.
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = torch.nn.functional.mse_loss(ground_truth_px_values, regenerated_px_values)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        scheduler.step()
    return training_loss
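# Optional helper (not in the original script): PSNR, the quality metric reported
# in the paper's benchmarks, computed from MSE for pixel values in [0, 1].
def psnr(mse):
    return -10.0 * np.log10(mse)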
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Each dataset row packs one pixel: [ray origin (3), ray direction (3), RGB (3)].
    training_dataset = torch.from_numpy(np.load('training_data.pkl', allow_pickle=True))
    testing_dataset = torch.from_numpy(np.load('testing_data.pkl', allow_pickle=True))

    model = NerfModel(N=256).to(device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)
    data_loader = DataLoader(training_dataset, batch_size=2048, shuffle=True)

    train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=device, hn=2, hf=6, nb_bins=192)
    for img_index in [0, 60, 120, 180]:
        test(2, 6, testing_dataset, img_index=img_index, nb_bins=192, H=400, W=400)