주뇽's 저장소

NeRF 본문

ComputerVision/NeRF

NeRF

뎁쭌 2023. 7. 1. 01:58
728x90
반응형

NeRF

NeRF (Neural Radiance Fields)는 신경망을 사용하여 3D 공간의 시각적인 표현을 학습하는 방법이다. NeRF는 3D 장면의 형상과 텍스처 정보를 캡처하여 다양한 관점에서의 렌더링을 수행할 수 있다.

특징

  1. 3D 시각화: NeRF는 3D 공간의 형상과 텍스처를 고해상도로 재구성하여 현실적이고 자연스러운 시각화를 제공하며 이를 통해 실제 세계와 유사한 시각적인 품질을 달성할 수 있다.
  2. 시야 변화에 대한 일관성: NeRF는 여러 각도에서의 뷰를 학습하여 3D 공간을 모델링하므로 시야 변화에 따라 일관된 렌더링 결과를 제공한다. 이는 회전, 이동 등의 변환에 대해 시각적인 일관성을 유지할 수 있다.
  3. 장면의 광원 모델링: NeRF는 공간 내의 광원에 대한 모델링도 수행한다. 이는 장면에 따른 조명의 영향을 재현하여 더욱 현실적인 렌더링 결과를 얻을 수 있다.

장점

  1. 고해상도 시각화: NeRF는 고해상도의 3D 시각화를 제공하므로 디테일한 구조와 텍스처를 포착할 수 있다. 이는 시각적으로 더욱 명확하고 생생한 결과물을 얻을 수 있다.
  2. 뷰 인터폴레이션: NeRF는 학습된 3D 모델을 사용하여 다양한 시점에서의 뷰를 생성할 수 있다. 이는 중간 시점에서의 뷰를 보간하여 새로운 시점의 이미지를 생성할 수 있다는 장점을 제공한다.

단점

  1. 계산적인 복잡성: NeRF는 대량의 파라미터와 계산량이 요구된다. 고해상도의 3D 모델을 학습하고 렌더링하기 위해서는 상당한 컴퓨팅 자원과 시간이 필요하다.
  2. 학습 데이터 의존성: NeRF는 큰 규모의 학습 데이터셋이 필요한다. 실제로 다양한 관점에서의 장면을 포착하는 데는 많은 이미지와 해당 이미지에서의 깊이 맵이 필요하다. 이러한 데이터 수집은 비용과 노력이 많이 들어간다
  • 최근에는 Few Shot Learning을 통해 극복하고자 하는 노력들이 있다.
  1. 실시간 렌더링의 한계: NeRF는 실시간 렌더링에는 적합하지 않다. 학습된 모델을 사용하여 새로운 뷰를 생성하는 데에는 상당한 계산 시간이 소요된다.

NeRF는 고해상도의 3D 시각화와 다양한 관점에서의 일관된 렌더링을 제공한다.

NeRF:NeRF: Representing scenes as neural radiance fields for view synthesis 논문을 간단히 구현하는 과제를 수행합니다.

 

import torch.nn as nn

class NerfModel(torch.nn.Module):
  r"""Define a NeRF model comprising three fully connected layers.
  """
  def __init__(self, D=6, filter_size=256, skip=[4], num_encoding_functions=10, num_encoding_dir_functions=4):
    """
    Args:
      D : 중간 layer의 개수 (주황색 화살표 이전의 layer)
      filer_size : layer의 channel 개수
      skip : skip connection을 추가할 layer의 number
      num_encoding_functions : 3d point의 positional embedding freq number
      num_encoding_dir_functions : direction의 positional embedding freq number
    """
    super(NerfModel, self).__init__()
    self.input_ch = 3 + 3 * 2 * num_encoding_functions
    self.input_ch_views = 3 + 3 * 2 *num_encoding_dir_functions
    self.skips = skip
    ############## START CODE HERE ##############
    # 6 layers
    self.linear_input   =    nn.Sequential(nn.Linear(self.input_ch, filter_size),
                                       nn.ReLU(inplace=True))
    
    self.linear_x   =    nn.Sequential(nn.Linear(filter_size, filter_size),
                                       nn.ReLU(inplace=True))
    
    self.linear_skip   =    nn.Sequential(nn.Linear(filter_size+ self.input_ch, filter_size),
                                       nn.ReLU(inplace=True))
    # density
    self.linear_density = nn.Linear(filter_size, 1)

    # color
    self.linear = nn.Linear(filter_size, filter_size)
    self.linear_color = nn.Sequential(nn.Linear(self.input_ch_views + filter_size, filter_size//2),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(filter_size // 2, 3)
                                      )
    
    ############## END CODE HERE ##############
  def forward(self, x):
    input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
    ############## START CODE HERE ##############
    x = self.linear_input(input_pts)
    # print("X Shape is " , x.shape)
    # x = self.linear_x(x)
    
    for i in range(2,7):
      if i == self.skips[0]:
        x = torch.cat([input_pts, x], dim=-1)
        x = self.linear_skip(x)
      else:
        x = self.linear_x(x)
    # density
    alpha = self.linear_density(x)

    # color
    x = self.linear(x)
    x = torch.cat([x, input_views], dim=-1)
    rgb = self.linear_color(x)
    ############## END CODE HERE ##############
    outputs = torch.cat([rgb, alpha], -1)
    return outputs

"""
Parameters for TinyNeRF training
"""

# Number of functions used in the positional encoding (Be sure to update the 
# model if this number changes).
num_encoding_functions = 10
num_encoding_dir_functions = 4
# Specify encoding function.
encode = lambda x: positional_encoding(x, num_encoding_functions=num_encoding_functions)
dir_encode = lambda x: positional_encoding(x, num_encoding_functions=num_encoding_dir_functions)
# Number of depth samples along each ray.
depth_samples_per_ray = 64

# Chunksize (Note: this isn't batchsize in the conventional sense. This only
# specifies the number of rays to be queried in one go. Backprop still happens
# only after all rays from the current "bundle" are queried and rendered).
chunksize = 16384  # Use chunksize of about 4096 to fit in ~1.4 GB of GPU memory.

# Optimizer parameters
lr = 5e-4
num_iters = 20000

# Misc parameters
display_every = 100  # Number of iters after which stats are displayed

"""
Model
""" 
model = NerfModel(D=6, filter_size=256, skip=[4], num_encoding_functions=num_encoding_functions, num_encoding_dir_functions=num_encoding_dir_functions)
model.to(device)

"""
Optimizer
"""
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

"""
Train-Eval-Repeat!
"""

# Seed RNG, for repeatability
seed = 123
torch.manual_seed(seed)
np.random.seed(seed)

# Lists to log metrics etc.
psnrs = []
iternums = []

for i in range(num_iters):

  # Randomly pick an image as the target.
  target_img_idx = np.random.randint(images.shape[0])
  target_img = images[target_img_idx].to(device)
  target_tform_cam2world = tform_cam2world[target_img_idx].to(device)

  # Run one iteration of TinyNeRF and get the rendered RGB image.
  rgb_predicted = run_one_iter_of_tinynerf(height, width, focal_length,
                                           target_tform_cam2world, near_thresh,
                                           far_thresh, depth_samples_per_ray,
                                           encode, dir_encode, get_minibatches)

  # Compute mean-squared error between the predicted and target images. Backprop!
  loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()

  # Display images/plots/stats
  if i % display_every == 0:
    # Render the held-out view
    rgb_predicted = run_one_iter_of_tinynerf(height, width, focal_length,
                                             testpose, near_thresh,
                                             far_thresh, depth_samples_per_ray,
                                             encode, dir_encode, get_minibatches)
    loss = torch.nn.functional.mse_loss(rgb_predicted, testimg)
    print("Loss:", loss.item())
    psnr = -10. * torch.log10(loss)
    
    psnrs.append(psnr.item())
    iternums.append(i)

    plt.figure(figsize=(10, 4))
    plt.subplot(121)
    plt.imshow(rgb_predicted.detach().cpu().numpy())
    plt.title(f"Iteration {i}")
    plt.subplot(122)
    plt.plot(iternums, psnrs)
    plt.title("PSNR")
    plt.show()

print('Done!')

 

'ComputerVision > NeRF' 카테고리의 다른 글

4장 심층학습(deep learning)  (0) 2021.10.31