Code for NeRF--: Neural Radiance Fields Without Known Camera Parameters

Implementation in 100 lines of Python · NeRF--
View on GitHub →
Abstract (original paper)

Considering the problem of novel view synthesis (NVS) from only a set of 2D images, we simplify the training process of Neural Radiance Field (NeRF) on forward-facing scenes by removing the requirement of known or pre-computed camera parameters, including both intrinsics and 6DoF poses. To this end, we propose NeRF--, with three contributions: First, we show that the camera parameters can be jointly optimised as learnable parameters with NeRF training, through a photometric reconstruction; Second, to benchmark the camera parameter estimation and the quality of novel view renderings, we introduce a new dataset of path-traced synthetic scenes, termed as Blender Forward-Facing Dataset (BLEFF); Third, we conduct extensive analyses to understand the training behaviours under various camera motions, and show that in most scenarios, the joint optimisation pipeline can recover accurate camera parameters and achieve comparable novel view synthesis quality as those trained with COLMAP pre-computed camera parameters. Our code and data are available at https://nerfmm.active.vision.

Source: NeRF--: Neural Radiance Fields Without Known Camera Parameters (2021-02-14).
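
The key mechanism, stated as an equation before the listing: with NeRF weights $\Theta$, a 6DoF pose $\pi_i$ per training image $I_i$ (axis-angle rotation plus translation) and a shared focal length $f$, the code below solves

$\Theta^*, \{\pi_i\}^*, f^* = \arg\min_{\Theta, \{\pi_i\}, f} \sum_i \sum_{p} \big\| \hat{I}(p; \Theta, \pi_i, f) - I_i(p) \big\|_2^2$

(in practice over random batches of pixels $p$), i.e. the standard NeRF photometric reconstruction loss, with gradients also flowing into the camera parameters, which are ordinary learnable tensors.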

Code

NeRF-- in 100 lines (Python)

import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.image import imread
from tqdm import tqdm


@torch.no_grad()
def test(model, camera_intrinsics, camera_extrinsics, hn, hf, images, chunk_size=10, img_index=0, nb_bins=192, H=400,
         W=400):
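    # Renders a full H x W novel view, chunk_size rows of pixels at a time (to bound GPU memory), and saves it.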
    ray_origins, ray_directions, _ = sample_batch(camera_extrinsics, camera_intrinsics, images, None, H, W,
                                                  img_index=img_index, sample_all=True)
    data = []
    for i in range(int(np.ceil(H / chunk_size))):
        ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(camera_intrinsics.device)
        ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(camera_intrinsics.device)
        regenerated_px_values = render_rays(model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values)
    img = torch.cat(data).data.cpu().numpy().reshape(H, W, 3)
    plt.imshow(img)
    os.makedirs('Imgs', exist_ok=True)  # Make sure the output directory exists
    plt.savefig('Imgs/novel_view.png', bbox_inches='tight')
    plt.close()


class NerfModel(nn.Module):
    def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim=128):
        super(NerfModel, self).__init__()

        self.block1 = nn.Sequential(nn.Linear(embedding_dim_pos * 6 + 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        self.block2 = nn.Sequential(nn.Linear(embedding_dim_pos * 6 + hidden_dim + 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim + 1), )
        self.block3 = nn.Sequential(nn.Linear(embedding_dim_direction * 6 + hidden_dim + 3, hidden_dim // 2),
                                    nn.ReLU(), )
        self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(), )

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.relu = nn.ReLU()

    @staticmethod
    def positional_encoding(x, L):
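        # gamma(x) = (x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)), applied element-wise
        # (cf. NeRF Eq. 4; this variant also keeps the raw input and omits the pi factor).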
        out = [x]
        for j in range(L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        return torch.cat(out, dim=1)

    def forward(self, o, d):
        emb_x = self.positional_encoding(o, self.embedding_dim_pos)
        emb_d = self.positional_encoding(d, self.embedding_dim_direction)
        tmp = self.block2(torch.cat((self.block1(emb_x), emb_x), dim=1))
        h, sigma = tmp[:, :-1], self.relu(tmp[:, -1])
        c = self.block4(self.block3(torch.cat((h, emb_d), dim=1)))
        return c, sigma


def compute_accumulated_transmittance(alphas):
    accumulated_transmittance = torch.cumprod(alphas, 1)
    return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
                      accumulated_transmittance[:, :-1]), dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
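    # Discrete volume rendering along each ray (the quadrature of NeRF Eq. 3):
    #   alpha_i = 1 - exp(-sigma_i * delta_i),  T_i = prod_{j<i} (1 - alpha_j),  C(r) = sum_i T_i * alpha_i * c_i.
    # Sample locations t_i are drawn by stratified sampling (one uniform draw per bin) and the last interval is
    # given a huge length (1e10) so the final sample can absorb the remaining transmittance.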
    device = ray_origins.device
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)  # Perturb sampling along each ray.
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)

    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [batch_size, nb_bins, 3]
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    alpha = 1 - torch.exp(-sigma.reshape(x.shape[:-1]) * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    return (weights * colors.reshape(x.shape)).sum(dim=1)  # Pixel values


def train(nerf_model, optimizers, schedulers, training_images, camera_extrinsics, camera_intrinsics, batch_size,
          nb_epochs, hn=0., hf=1., nb_bins=192):
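    # Each epoch visits every training image once in random order; for each image a random batch of pixels is
    # rendered and the summed squared photometric error is backpropagated into all parameters handled by the
    # optimizers that were passed in (NeRF weights and, in the joint stage, camera poses and focal length).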
    H, W = training_images.shape[1:3]

    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        ids = np.arange(training_images.shape[0])
        np.random.shuffle(ids)
        for img_index in ids:
            rays_o, rays_d, samples_idx = sample_batch(camera_extrinsics, camera_intrinsics, training_images,
                                                       batch_size, H, W, img_index=img_index)
            gt_px_values = torch.from_numpy(training_images[samples_idx]).to(camera_intrinsics.device)
            regenerated_px_values = render_rays(nerf_model, rays_o, rays_d, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((gt_px_values - regenerated_px_values) ** 2).sum()

            for optimizer in optimizers:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in optimizers:
                optimizer.step()
            training_loss.append(loss.item())
        for scheduler in schedulers:
            scheduler.step()
    return training_loss


def initialize_camera_parameters(images, device='cpu'):
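    # All poses start at the identity (zero axis-angle rotation and zero translation) and the intrinsic parameter
    # at 1, i.e. an initial focal length of W pixels (see sample_batch); requires_grad=True makes both tensors
    # learnable, so the joint optimisation can refine them.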
    camera_intrinsics = torch.ones(1, device=device, requires_grad=True)
    camera_extrinsics = torch.zeros((images.shape[0], 6), device=device, dtype=torch.float32, requires_grad=True)
    return camera_intrinsics, camera_extrinsics


def load_images(data_path):
    image_paths = sorted(glob.glob(data_path))  # Sort the paths for a deterministic image ordering
    images = None
    for image_path in image_paths:
        img = np.expand_dims(imread(image_path)[..., :3], 0)  # Keep only the RGB channels (drop alpha if present)
        images = np.concatenate((images, img)) if images is not None else img
    return images


def get_ndc_rays(H, W, focal, rays_o, rays_d, near=1.):
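    # Forward-facing scenes are rendered in normalised device coordinates (NDC, NeRF Appendix C): rays are remapped
    # so that the unbounded frustum between the near plane and infinity fits into a volume sampled with hn=0, hf=1.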
    # Shift each origin o to the ray's intersection with the near plane at z = -near (before the NDC conversion)
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Project to NDC, Eqs. 25 and 26 of https://arxiv.org/pdf/2003.08934.pdf.
    # The direction terms must use the shifted (pre-projection) origin, so compute everything before overwriting.
    o0 = -2. * focal / W * rays_o[..., 0] / rays_o[..., 2]
    o1 = -2. * focal / H * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]
    d0 = -2. * focal / W * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = -2. * focal / H * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]
    return torch.stack([o0, o1, o2], -1), torch.stack([d0, d1, d2], -1)


def sample_batch(camera_extrinsics, camera_intrinsics, images, batch_size, H, W, img_index=0, sample_all=False):
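    # Builds rays for image `img_index`: one ray per pixel if sample_all, otherwise batch_size random pixels.
    # Returns NDC-space ray origins and (normalised) directions, plus the (image, row, column) indices of the
    # sampled pixels, used by the caller to look up the ground-truth colours.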
    if sample_all:
        image_indices = (torch.zeros(W * H) + img_index).type(torch.long)
        u, v = np.meshgrid(np.linspace(0, W - 1, W, dtype=int), np.linspace(0, H - 1, H, dtype=int))
        u = torch.from_numpy(u.reshape(-1)).to(camera_intrinsics.device)
        v = torch.from_numpy(v.reshape(-1)).to(camera_intrinsics.device)
    else:
        image_indices = (torch.zeros(batch_size) + img_index).type(torch.long)  # Every ray in this batch comes from image img_index
        u = torch.randint(W, (batch_size,), device=camera_intrinsics.device)  # Sample random pixels
        v = torch.randint(H, (batch_size,), device=camera_intrinsics.device)

    # The learnable intrinsic is the square root of the focal length in units of the image width: squaring it keeps
    # the recovered focal length positive (one focal shared by x and y, principal point at the image centre).
    focal = camera_intrinsics[0] ** 2 * W
    t = camera_extrinsics[img_index, :3]
    r = camera_extrinsics[img_index, -3:]

    # Creating the c2w matrix, Section 4.1 from the paper
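    # r is an axis-angle vector; Rodrigues' formula R = I + sin(a)/a * K + ((1 - cos(a)) / a^2) * K^2, with
    # a = ||r|| and K the skew-symmetric matrix of r, converts it to a rotation matrix (the 1e-15 avoids a
    # division by zero when r = 0).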
    phi_skew = torch.stack([torch.cat([torch.zeros(1, device=r.device), -r[2:3], r[1:2]]),
                            torch.cat([r[2:3], torch.zeros(1, device=r.device), -r[0:1]]),
                            torch.cat([-r[1:2], r[0:1], torch.zeros(1, device=r.device)])], dim=0)
    alpha = r.norm() + 1e-15
    R = torch.eye(3, device=r.device) + (torch.sin(alpha) / alpha) * phi_skew + (
            (1 - torch.cos(alpha)) / alpha ** 2) * (phi_skew @ phi_skew)
    c2w = torch.cat([R, t.unsqueeze(1)], dim=1)
    c2w = torch.cat([c2w, torch.tensor([[0., 0., 0., 1.]], device=c2w.device)], dim=0)

    # Pixel coordinates -> camera-space ray directions (pinhole model, camera looking down the -z axis)
    rays_d_cam = torch.cat([((u.to(camera_intrinsics.device) - .5 * W) / focal).unsqueeze(-1),
                            (-(v.to(camera_intrinsics.device) - .5 * H) / focal).unsqueeze(-1),
                            -torch.ones_like(u, dtype=torch.float32).unsqueeze(-1)], dim=-1)
    rays_d_world = torch.matmul(c2w[:3, :3].view(1, 3, 3), rays_d_cam.unsqueeze(2)).squeeze(2)
    rays_o_world = c2w[:3, 3].view(1, 3).expand_as(rays_d_world)
    rays_o_world, rays_d_world = get_ndc_rays(H, W, focal, rays_o=rays_o_world, rays_d=rays_d_world)
    return rays_o_world, F.normalize(rays_d_world, p=2, dim=1), (image_indices, v.cpu(), u.cpu())


if __name__ == "__main__":
    device = 'cuda'
    nb_epochs = int(1e4)

    training_images = load_images("fern/images_4/*.png")
    camera_intrinsics, camera_extrinsics = initialize_camera_parameters(training_images, device=device)
    batch_size = 1024

    # Part 1: jointly optimise the NeRF, the per-image camera poses and the focal length
    model = NerfModel(hidden_dim=256).to(device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    optimizer_camera_parameters = torch.optim.Adam([camera_extrinsics], lr=0.0009)
    optimizer_focal = torch.optim.Adam([camera_intrinsics], lr=0.001)
    scheduler_model = torch.optim.lr_scheduler.MultiStepLR(
        model_optimizer, [10 * (i + 1) for i in range(nb_epochs // 10)], gamma=0.9954)
    scheduler_camera_parameters = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_camera_parameters, [100 * (i + 1) for i in range(nb_epochs // 100)], gamma=0.81)
    scheduler_focal = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_focal, [100 * (i + 1) for i in range(nb_epochs // 100)], gamma=0.9)
    train(model, [model_optimizer, optimizer_camera_parameters, optimizer_focal],
          [scheduler_model, scheduler_camera_parameters, scheduler_focal], training_images, camera_extrinsics,
          camera_intrinsics, batch_size, nb_epochs, hn=0., hf=1., nb_bins=192)

    # Part 2: retrain the NeRF from scratch, keeping the refined camera parameters fixed
    model = NerfModel(hidden_dim=256).to(device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler_model = torch.optim.lr_scheduler.MultiStepLR(
        model_optimizer, [10 * (i + 1) for i in range(nb_epochs // 10)], gamma=0.9954)
    train(model, [model_optimizer], [scheduler_model], training_images, camera_extrinsics, camera_intrinsics,
          batch_size, nb_epochs, hn=0., hf=1., nb_bins=192)

    # Test: render a novel view from a pose interpolated between the first two optimised camera poses
    test(model, camera_intrinsics, (.5 * camera_extrinsics[0] + .5 * camera_extrinsics[1]).unsqueeze(0), 0., 1.,
         training_images, img_index=0, nb_bins=192, H=training_images.shape[1], W=training_images.shape[2])
