In [1]:
import enoki as ek
import mitsuba
mitsuba.set_variant('gpu_autodiff_rgb')

from mitsuba.core import Float, Thread, Bitmap, Struct           # add Bitmap, Struct
from mitsuba.core.xml import load_file
from mitsuba.python.util import traverse
from mitsuba.python.autodiff import render, write_bitmap, Adam
import time

# add
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80

import os
import numpy as np
In [2]:
# シーン(bunny)を読み込む
scene = load_file('bunny.xml')
2020-04-13 15:59:43 INFO main [xml.cpp:1129] Loading XML file "bunny.xml" ..
2020-04-13 15:59:43 INFO main [xml.cpp:1130] Using variant "gpu_autodiff_rgb"
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/path.so" ..
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/envmap.so" ..
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/uniform.so" ..
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/conductor.so" ..
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/ply.so" ..
2020-04-13 15:59:43 WARN main [PLYMesh] "bunny.ply": performance warning -- this file uses the ASCII PLY format, which is slow to parse. Consider converting it to the binary PLY format.
2020-04-13 15:59:43 WARN main [Mesh] "bunny.ply": computed vertex normals (took 1ms, 1113 invalid vertices!)
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/hdrfilm.so" ..
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/gaussian.so" ..
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/independent.so" ..
2020-04-13 15:59:43 INFO main [PluginManager] Loading plugin "plugins/perspective.so" ..
2020-04-13 15:59:43 INFO main [Scene] Validating and building scene in OptiX.
In [3]:
# 微分可能なシーンパラメータを探す
params = traverse(scene)
print(params)
ParameterMap[
    my_envmap.scale,
  * my_envmap.data,
    my_envmap.resolution,
  * PLYMesh.bsdf.specular_reflectance.value,
  * PLYMesh.bsdf.eta.value,
  * PLYMesh.bsdf.k.value,
    PLYMesh.vertex_count,
    PLYMesh.face_count,
    PLYMesh.faces,
  * PLYMesh.vertex_positions,
  * PLYMesh.vertex_normals,
  * PLYMesh.vertex_texcoords,
    PerspectiveCamera.near_clip,
    PerspectiveCamera.far_clip,
    PerspectiveCamera.focus_distance,
    PerspectiveCamera.shutter_open,
    PerspectiveCamera.shutter_open_time,
]
In [4]:
# 画像の配列とサイズを渡してmatplotlibで描画する関数
def plotBMP(data, size, title, figsize):
    channel = len(data) // (size[0]*size[1])
    print("size={0}, {1}, {2}".format(size[1], size[0], channel))
    bmp = Bitmap(np.array(data).reshape(size[1], size[0], channel))
    bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)

    # 表示
    plt.figure(figsize=figsize)
    plt.imshow(bmp_rgb)
    plt.title(title)
    plt.grid("off")
    plt.axis("off")
    plt.show()
In [5]:
# バックアップを作成する
param_res = params['my_envmap.resolution']
param_ref = Float(params['my_envmap.data'])
print("resolution=", param_res)
print("data=", param_ref)

# 確認
plotBMP(param_ref, param_res, "museum.exr", (5, 2.5))
resolution= [250, 125]
data= [0.149571, 0.0872033, 0.0425204, 1, 0.149695, .. 124990 skipped .., 1, 8.04265e-05, 7.37574e-06, -4.16967e-05, 1]
size=125, 250, 4
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [6]:
# 微分したいパラメータを除いて、すべてのパラメータを破棄する
params.keep(['my_envmap.data'])

# 目的画像をレンダリングする(まだ導関数は使用していない)
image_ref = render(scene, spp=16)
crop_size = scene.sensors()[0].film().crop_size()

# 確認
plotBMP(image_ref, crop_size, "image_ref", (5,5))
size=512, 512, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [7]:
# 白で統一された環境証明マップに変更(初期画像)
params['my_envmap.data'] = ek.full(Float, 1.0, len(param_ref))
params.update()

# 確認
image = render(scene, spp=16)
plotBMP(image, crop_size, "image", (5,5))
size=512, 512, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [8]:
# パラメータの最適化手法としてAdamを選択
opt = Adam(params, lr=.02)
In [9]:
# 学習
time_a = time.time()

iterations = 101

for it in range(iterations):
    # シーンの微分可能レンダリングを実行
    image = render(scene, optimizer=opt, unbiased=True, spp=1)
    
    # ロスを取る(レンダリング画像と目的画像のMSE)
    ob_val = ek.hsum(ek.sqr(image - image_ref)) / len(image)
    
    # 誤差を入力パラメータに逆伝搬する
    ek.backward(ob_val)
    
    # 勾配のステップを実行
    opt.step()
    
    # 環境光源マップについてGround Truthの値と比較した結果を出力する
    err_ref = ek.hsum(ek.sqr(param_ref - params['my_envmap.data']))
    print('Iteration %03i: error=%g' % (it, err_ref[0]), end='\r')
    
    if it % 10 == 0:
        plotBMP(image, crop_size, 'Iteration %03i' % it, (5, 5))
    
time_b = time.time()

print()
print('%f ms per iteration' % (((time_b - time_a) * 1000) / iterations))
size=512, 512, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=61333.2
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=42739.2
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=29933.2
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=21876.4
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=17170.2
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=14512.1
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=12932.1
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=11878.4
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=11081.3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=10444.9
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
45.544554 ms per iteration
In [10]:
err_ref = ek.hsum(ek.sqr(param_ref - params['my_envmap.data']))
print('Iteration %03i: error=%g' % (it, err_ref[0]), end='\r')
Iteration 100: error=10444.9
In [ ]: