import enoki as ek
import mitsuba
mitsuba.set_variant('gpu_autodiff_rgb')
from mitsuba.core import Thread, Color3f, Bitmap, Struct
from mitsuba.core.xml import load_file
from mitsuba.python.util import traverse
from mitsuba.python.autodiff import render, write_bitmap, Adam
import time
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
import os
import numpy as np
# コーネルボックスを読み込む
Thread.thread().file_resolver().append('cbox')
scene = load_file('cbox.xml')
# 微分可能なシーンパラメータを探す
params = traverse(scene)
# 微分したいパラメータを除いて、すべてのパラメータを破棄する
params.keep(['red.reflectance.value'])
# 現在の値を出力し、バックアップを生成する
param_ref = Color3f(params['red.reflectance.value'])
print(param_ref)
# 目的画像をレンダリングする(まだ導関数は使用していない)
image_ref = render(scene, spp=8)
bmp = Bitmap(np.array(image_ref).reshape(256, 256, 3))
bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)
# 結果を表示
plt.figure(figsize=(5, 5))
plt.imshow(bmp_rgb)
plt.title("Target")
plt.grid("off")
plt.axis("off")
plt.show()
# 左側の壁を白に変更する(初期値)
params['red.reflectance.value'] = [.9, .9, .9]
params.update()
#確認
image_white = render(scene, spp=8)
bmp = Bitmap(np.array(image_white).reshape(256, 256, 3))
bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)
plt.figure(figsize=(5, 5))
plt.imshow(bmp_rgb)
plt.title("Target")
plt.grid("off")
plt.axis("off")
plt.show()
# パラメータの最適化手法としてAdamを選択
opt = Adam(params, lr=.2)
time_a = time.time()
# イタレーション数
iterations = 100
for it in range(iterations):
# シーンの微分可能レンダリングを実行
image = render(scene, optimizer=opt, unbiased=True, spp=1)
# ロスを取る(レンダリング画像と目的画像の間のMSE)
ob_val = ek.hsum(ek.sqr(image - image_ref)) / len(image)
# 誤差を入力パラメータに逆伝搬する
ek.backward(ob_val)
# 勾配のステップを実行
opt.step()
# 反射率についてGround Truthの値と比較した結果を出力する
err_ref = ek.hsum(ek.sqr(param_ref - params['red.reflectance.value']))
print('Iteration %03i: error=%g' % (it, err_ref[0]), end='\r')
# 結果を表示
if (it % 5 == 0):
bmp = Bitmap(np.array(image).reshape(256, 256, 3))
bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)
plt.figure(figsize=(5, 5))
plt.imshow(bmp_rgb)
plt.title("Target")
plt.grid("off")
plt.axis("off")
plt.show()
time_b = time.time()
# イタレーション毎の計算時間の平均を出力
print()
print('%f ms per iteration' % (((time_b - time_a) * 1000) / iterations))
# 正解データ
print(param_ref)
# 学習後のデータ(初期値(0.9, 0.9, 0.9))
print(params['red.reflectance.value'])