In [1]:
import enoki as ek
import mitsuba
mitsuba.set_variant('gpu_autodiff_rgb')

from mitsuba.core import Thread, Color3f, Bitmap, Struct
from mitsuba.core.xml import load_file
from mitsuba.python.util import traverse
from mitsuba.python.autodiff import render, write_bitmap, Adam
import time

%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80

import os
import numpy as np
In [2]:
# コーネルボックスを読み込む
Thread.thread().file_resolver().append('cbox')
scene = load_file('cbox.xml')
2020-04-12 20:24:02 INFO main [xml.cpp:1129] Loading XML file "cbox.xml" ..
2020-04-12 20:24:02 INFO main [xml.cpp:1130] Using variant "gpu_autodiff_rgb"
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/regular.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/srgb.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/srgb_d65.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/path.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/independent.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/gaussian.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/hdrfilm.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/perspective.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/diffuse.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/area.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/uniform.so" ..
2020-04-12 20:24:02 INFO main [PluginManager] Loading plugin "plugins/obj.so" ..
2020-04-12 20:24:02 INFO main [Scene] Validating and building scene in OptiX.
In [3]:
# 微分可能なシーンパラメータを探す
params = traverse(scene)
In [4]:
# 微分したいパラメータを除いて、すべてのパラメータを破棄する
params.keep(['red.reflectance.value'])
In [5]:
# 現在の値を出力し、バックアップを生成する
param_ref = Color3f(params['red.reflectance.value'])
print(param_ref)
[[0.569717, 0.0430141, 0.0443234]]
In [6]:
# 目的画像をレンダリングする(まだ導関数は使用していない)
image_ref = render(scene, spp=8)
bmp = Bitmap(np.array(image_ref).reshape(256, 256, 3))
bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)

# 結果を表示
plt.figure(figsize=(5, 5))
plt.imshow(bmp_rgb)
plt.title("Target")
plt.grid("off")
plt.axis("off")
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [7]:
# 左側の壁を白に変更する(初期値)
params['red.reflectance.value'] = [.9, .9, .9]
params.update()

#確認
image_white = render(scene, spp=8)
bmp = Bitmap(np.array(image_white).reshape(256, 256, 3))
bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)
plt.figure(figsize=(5, 5))
plt.imshow(bmp_rgb)
plt.title("Target")
plt.grid("off")
plt.axis("off")
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [8]:
# パラメータの最適化手法としてAdamを選択
opt = Adam(params, lr=.2)
In [9]:
time_a = time.time()

# イタレーション数
iterations = 100

for it in range(iterations):
    # シーンの微分可能レンダリングを実行
    image = render(scene, optimizer=opt, unbiased=True, spp=1)

    # ロスを取る(レンダリング画像と目的画像の間のMSE)
    ob_val = ek.hsum(ek.sqr(image - image_ref)) / len(image)

    # 誤差を入力パラメータに逆伝搬する
    ek.backward(ob_val)

    # 勾配のステップを実行
    opt.step()

    # 反射率についてGround Truthの値と比較した結果を出力する
    err_ref = ek.hsum(ek.sqr(param_ref - params['red.reflectance.value']))
    print('Iteration %03i: error=%g' % (it, err_ref[0]), end='\r')

    # 結果を表示
    if (it % 5 == 0):
        bmp = Bitmap(np.array(image).reshape(256, 256, 3))
        bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)
        plt.figure(figsize=(5, 5))
        plt.imshow(bmp_rgb)
        plt.title("Target")
        plt.grid("off")
        plt.axis("off")
        plt.show()

time_b = time.time()
Iteration 000: error=0.879099
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 005: error=0.0948403
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 010: error=0.374994
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 015: error=0.090083
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 020: error=0.02590344
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 025: error=0.0696904
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 030: error=0.0028938
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 035: error=0.018279729
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 040: error=0.0103674
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 045: error=0.001207772
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 050: error=0.00544792
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 055: error=0.000401953
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 060: error=0.001983226
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 065: error=0.000289912
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 070: error=0.000562276
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 075: error=5.02116e-05
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 080: error=0.000431787
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 085: error=5.04634e-05
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 090: error=3.58091e-05
/home/demo/.conda/envs/mitsuba2/lib/python3.7/site-packages/ipykernel_launcher.py:27: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 095: error=3.62908e-05
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Iteration 099: error=0.000516669
In [10]:
# イタレーション毎の計算時間の平均を出力
print()
print('%f ms per iteration' % (((time_b - time_a) * 1000) / iterations))

# 正解データ
print(param_ref)
# 学習後のデータ(初期値(0.9, 0.9, 0.9))
print(params['red.reflectance.value'])
81.049168 ms per iteration
[[0.569717, 0.0430141, 0.0443234]]
[[0.549746, 0.035356, 0.036632]]
In [ ]: