Skip to content

Commit

Permalink
add use_opengl flag
Browse files Browse the repository at this point in the history
  • Loading branch information
sicxu committed Sep 13, 2022
1 parent 6eeb3b9 commit 1d8e2b8
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ python test.py --name=<model_name> --epoch=20 --img_folder=<folder_to_test_image
# get reconstruction results of example images
python test.py --name=<model_name> --epoch=20 --img_folder=./datasets/examples
```
**_Following [#108](https://github.com/sicxu/Deep3DFaceRecon_pytorch/issues/108), if you don't have OpenGL environment, you can simply add "--use_opengl False" to use CUDA context. Make sure you have updated the nvdiffrast to the latest version._**

Results will be saved into ./checkpoints/<model_name>/results/<folder_to_test_images>, which contain the following files:
| \*.png | A combination of cropped input image, reconstructed image, and visualization of projected landmarks.
Expand Down
3 changes: 2 additions & 1 deletion models/facerecon_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def modify_commandline_options(parser, is_train=True):
parser.add_argument('--camera_d', type=float, default=10.)
parser.add_argument('--z_near', type=float, default=5.)
parser.add_argument('--z_far', type=float, default=15.)
parser.add_argument('--use_opengl', type=util.str2bool, nargs='?', const=True, default=False, help='use opengl context or not')

if is_train:
# training parameters
Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(self, opt):

fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
self.renderer = MeshRenderer(
rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)
rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center), use_opengl=opt.use_opengl
)

if self.isTrain:
Expand Down
1 change: 1 addition & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def main(rank, opt, name='examples'):
print(i, im_path[i])
img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','')
if not os.path.isfile(lm_path[i]):
print("%s is not found !!!"%lm_path[i])
continue
im_tensor, lm_tensor = read_data(im_path[i], lm_path[i], lm3d_std)
data = {
Expand Down
19 changes: 13 additions & 6 deletions util/nvdiffrast.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ def __init__(self,
rasterize_fov,
znear=0.1,
zfar=10,
rasterize_size=224):
rasterize_size=224,
use_opengl=True):
super(MeshRenderer, self).__init__()

x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
torch.diag(torch.tensor([1., -1, -1, 1])))
self.rasterize_size = rasterize_size
self.glctx = None
self.use_opengl = use_opengl
self.ctx = None

def forward(self, vertex, tri, feat=None):
"""
Expand All @@ -54,9 +56,14 @@ def forward(self, vertex, tri, feat=None):


vertex_ndc = vertex @ ndc_proj.t()
if self.glctx is None:
self.glctx = dr.RasterizeGLContext(device=device)
print("create glctx on device cuda:%d"%device.index)
if self.ctx is None:
if self.use_opengl:
self.ctx = dr.RasterizeGLContext(device=device)
ctx_str = "opengl"
else:
self.ctx = dr.RasterizeCudaContext(device=device)
ctx_str = "cuda"
print("create %s ctx on device cuda:%d"%(ctx_str, device.index))

ranges = None
if isinstance(tri, List) or len(tri.shape) == 3:
Expand All @@ -71,7 +78,7 @@ def forward(self, vertex, tri, feat=None):

# for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
tri = tri.type(torch.int32).contiguous()
rast_out, _ = dr.rasterize(self.glctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges)
rast_out, _ = dr.rasterize(self.ctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges)

depth, _ = dr.interpolate(vertex.reshape([-1,4])[...,2].unsqueeze(1).contiguous(), rast_out, tri)
depth = depth.permute(0, 3, 1, 2)
Expand Down

0 comments on commit 1d8e2b8

Please sign in to comment.