def get_rays(H, W, K, c2w): i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' i = i.t() j = j.t() dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) # normalized coordinate # Rotate ray directions from camera frame to the world frame rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)..