1) NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis (ECCV 2020)


Full translation of the paper (Korean): https://hsejun07.tistory.com/78

 

Goal: synthesize images of unseen views

(a) Sample 5D coordinates (position x, y, z and viewing direction θ, ɸ) along camera rays (a ray-generation sketch follows this list)

(b) Feed those coordinates into an MLP to produce an RGB color and a volume density σ

(c) Use volume rendering techniques to composite these values into an image

(d) Optimize the scene representation, which is possible because the rendering function is differentiable
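
Step (a) amounts to shooting one ray per pixel. Below is a minimal ray-generation sketch, close to the reference implementation's get_rays (H, W are the image size, focal the focal length, c2w the 3x4 camera-to-world pose):

import torch

def get_rays(H, W, focal, c2w):
    # pixel grid -> camera-space directions (pinhole model, -z forward)
    i, j = torch.meshgrid(torch.arange(W, dtype=torch.float32),
                          torch.arange(H, dtype=torch.float32), indexing='xy')
    dirs = torch.stack([(i - W * .5) / focal, -(j - H * .5) / focal, -torch.ones_like(i)], -1)
    # rotate the directions into world space; all rays share the camera origin
    rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], -1)
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d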

 

MLP Networks

The representation is encouraged to be multiview-consistent by restricting the MLP to predict the volume density σ as a function of the position x only, while allowing the RGB color c to be predicted as a function of both position and viewing direction.

The MLP F_Θ first processes the input 3D coordinate x with 8 fully-connected layers (ReLU activations, 256 channels per layer) and outputs σ together with a 256-dimensional feature vector.

This feature vector is then concatenated with the camera ray's viewing direction and passed through one additional fully-connected layer (ReLU activation, 128 channels) that outputs the view-dependent RGB color.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs
        # D fully-connected layers for the position, with a skip connection
        # that re-injects the input at the layers listed in `skips`
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
        # one additional layer (W//2 = 128 channels) that sees the view direction
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)   # density sigma, from position only
            self.rgb_linear = nn.Linear(W//2, 3)  # view-dependent RGB
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)      # sigma is predicted before the view direction enters
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)
            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        return outputs
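
A minimal sketch of how the network is driven. The channel counts assume the paper's positional encoding: L = 10 for the position gives 3 + 3·2·10 = 63 input channels, L = 4 for the direction gives 27:

model = NeRF(D=8, W=256, input_ch=63, input_ch_views=27, use_viewdirs=True)

# dummy batch of 1024 samples: encoded position (63) + encoded direction (27)
x = torch.rand(1024, 63 + 27)
out = model(x)   # (1024, 4): RGB (3 channels) + density sigma (1 channel)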

 

Volume Rendering with Radiance Fields

The color of any ray passing through the scene is rendered using principles from classical volume rendering.
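
In discrete form, the color of a ray r is estimated by numerical quadrature over the sampled points:

Ĉ(r) = Σᵢ Tᵢ (1 − exp(−σᵢ δᵢ)) cᵢ,   Tᵢ = exp(−Σ_{j<i} σⱼ δⱼ)

where δᵢ = t_{i+1} − tᵢ is the distance between adjacent samples. render_rays below performs the sampling along each ray, and the helper raw2outputs (sketched after it) evaluates this sum.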

 

import numpy as np
import torch

DEBUG = False  # module-level flag in the reference implementation

def render_rays(ray_batch, network_fn, network_query_fn, N_samples, retraw=False, lindisp=False, perturb=0., N_importance=0, network_fine=None, white_bkgd=False, raw_noise_std=0., verbose=False, pytest=False):
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # ray origins and directions
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]

    # N_samples depths between near and far, spaced linearly in depth
    # or (if lindisp) linearly in inverse depth (disparity)
    t_vals = torch.linspace(0., 1., steps=N_samples)
    if not lindisp:
        z_vals = near * (1. - t_vals) + far * (t_vals)
    else:
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))
    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # stratified sampling: jitter each depth uniformly within its stratum
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        t_rand = torch.rand(z_vals.shape)

        if pytest:
            # deterministic samples for testing
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand)

        z_vals = lower + (upper - lower) * t_rand
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]  # (N_rays, N_samples, 3)

    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    if N_importance > 0:
        # hierarchical sampling: draw N_importance extra depths from the
        # distribution defined by the coarse weights, then query the fine network
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, run_fn)
        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)

    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret
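
render_rays relies on raw2outputs, which is not shown above. Here is a condensed sketch of that helper, following the quadrature rule given earlier (minor details may differ from the original code):

import torch
import torch.nn.functional as F

def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0., white_bkgd=False, pytest=False):
    # delta_i: distance between adjacent samples (last interval treated as infinite),
    # scaled by the ray direction norm to get real-world distance
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3])  # c_i
    noise = torch.randn(raw[..., 3].shape) * raw_noise_std if raw_noise_std > 0. else 0.
    # alpha_i = 1 - exp(-sigma_i * delta_i)
    alpha = 1. - torch.exp(-F.relu(raw[..., 3] + noise) * dists)
    # T_i = prod_{j<i} (1 - alpha_j), via an exclusive cumulative product
    T = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)[..., :-1]
    weights = alpha * T

    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # expected color per ray
    depth_map = torch.sum(weights * z_vals, -1)        # expected depth
    acc_map = torch.sum(weights, -1)                   # accumulated opacity
    disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                              depth_map / torch.clamp(acc_map, min=1e-10))
    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])  # composite onto white
    return rgb_map, disp_map, acc_map, weights, depth_map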

 

Optimizing a Neural Radiance Field

1) Positional encoding
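
Before being fed to the MLP, each input coordinate p is mapped to a higher-dimensional space, which lets the network fit high-frequency detail:

γ(p) = (sin(2⁰πp), cos(2⁰πp), …, sin(2^{L−1}πp), cos(2^{L−1}πp))

The paper uses L = 10 for the position x and L = 4 for the viewing direction d (the reference code omits the π factor). The Embedder class below implements this encoding.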

 

import torch

class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            # optionally pass the raw input through unchanged
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)

        # one sin and one cos term per frequency band
        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
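
A quick sanity check of the encoding dimensions (parameter names follow the class above; with L = 10 a 3D position becomes 3 + 3·2·10 = 63 channels):

embedder = Embedder(input_dims=3, include_input=True,
                    max_freq_log2=9, num_freqs=10, log_sampling=True,
                    periodic_fns=[torch.sin, torch.cos])
x = torch.rand(1024, 3)        # batch of 3D positions
encoded = embedder.embed(x)    # shape (1024, 63)
print(embedder.out_dim)        # 63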

 

2) Hierarchical volume sampling
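
NeRF trains two networks of identical architecture: a coarse one, whose normalized weights along each ray define a piecewise-constant PDF, and a fine one evaluated at N_importance additional depths drawn from that PDF. render_rays above delegates the inverse transform sampling to sample_pdf; here is a condensed sketch of that helper (simplified from the reference implementation, with the seeded test branch omitted):

import torch

def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # build a normalized PDF / CDF over the coarse sample intervals
    weights = weights + 1e-5  # avoid division by zero
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # uniform samples in [0, 1): evenly spaced if det, random otherwise
    if det:
        u = torch.linspace(0., 1., steps=N_samples).expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
    u = u.contiguous()

    # invert the CDF: locate each u, then interpolate linearly within its bin
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)

    cdf_below = torch.gather(cdf, -1, below)
    cdf_above = torch.gather(cdf, -1, above)
    bins_below = torch.gather(bins, -1, below)
    bins_above = torch.gather(bins, -1, above)

    denom = cdf_above - cdf_below
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_below) / denom
    return bins_below + t * (bins_above - bins_below)

create_nerf below builds both the coarse and fine models and wires them into the render kwargs.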

 

import os
import torch

def create_nerf(args):
    # positional encoders for position and (optionally) view direction;
    # get_embedder, run_network, NeRF, and device come from the surrounding module
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)

    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]
    model = NeRF(D=args.netdepth, W=args.netwidth, input_ch=input_ch, output_ch=output_ch, skips=skips, input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    # fine network for hierarchical sampling (same architecture as the coarse one)
    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(D=args.netdepth, W=args.netwidth, input_ch=input_ch, output_ch=output_ch, skips=skips, input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn: run_network(inputs, viewdirs, network_fn, embed_fn=embed_fn, embeddirs_fn=embeddirs_fn, netchunk=args.netchunk)

    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    start = 0
    basedir = args.basedir
    expname = args.expname

    # reload the latest checkpoint, if one exists
    if args.ft_path is not None and args.ft_path != 'None':
        ckpts = [args.ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        ckpt = torch.load(ckpt_path)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    render_kwargs_train = {
        'network_query_fn': network_query_fn,
        'perturb': args.perturb,
        'N_importance': args.N_importance,
        'network_fine': model_fine,
        'N_samples': args.N_samples,
        'network_fn': model,
        'use_viewdirs': args.use_viewdirs,
        'white_bkgd': args.white_bkgd,
        'raw_noise_std': args.raw_noise_std,
    }

    # NDC rays only make sense for forward-facing LLFF data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    # test-time rendering: no stratified jitter, no density noise
    render_kwargs_test = {k: render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
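
As a usage sketch only: a hypothetical args namespace covering every field create_nerf reads above (values are illustrative, and the module-level helpers get_embedder, run_network, NeRF, and device must already be in scope):

import argparse, os

args = argparse.Namespace(
    netdepth=8, netwidth=256, N_samples=64, N_importance=128,
    use_viewdirs=True, multires=10, multires_views=4, i_embed=0,
    netchunk=1024*64, lrate=5e-4, perturb=1., raw_noise_std=0.,
    white_bkgd=False, dataset_type='blender', no_ndc=False, lindisp=False,
    basedir='./logs', expname='demo', ft_path=None, no_reload=False)

os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)  # checkpoint dir must exist
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)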