In [1]:
import enoki as ek
import mitsuba
mitsuba.set_variant('gpu_autodiff_rgb')

from mitsuba.core import Float, Thread, Bitmap, Struct           # add Bitmap, Struct
from mitsuba.core.xml import load_file
from mitsuba.python.util import traverse
from mitsuba.python.autodiff import render, write_bitmap, Adam
import time

# add
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80

import os
import numpy as np
In [2]:
# シーン(cow_mesh)を読み込む
scene = load_file('cow_mesh.xml')
2020-04-13 18:30:59 INFO main [xml.cpp:1129] Loading XML file "cow_mesh.xml" ..
2020-04-13 18:30:59 INFO main [xml.cpp:1130] Using variant "gpu_autodiff_rgb"
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/path.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/envmap.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/uniform.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/bitmap.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/diffuse.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/obj.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/hdrfilm.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/gaussian.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/independent.so" ..
2020-04-13 18:30:59 INFO main [PluginManager] Loading plugin "plugins/perspective.so" ..
2020-04-13 18:30:59 INFO main [Scene] Validating and building scene in OptiX.
In [3]:
# 微分可能なシーンパラメータを探す
params = traverse(scene)
print(params)
ParameterMap[
    my_envmap.scale,
  * my_envmap.data,
    my_envmap.resolution,
  * myImage.data,
    myImage.resolution,
    myImage.transform,
  * myMaterial.reflectance.data,
    myMaterial.reflectance.resolution,
    myMaterial.reflectance.transform,
    OBJMesh.vertex_count,
    OBJMesh.face_count,
    OBJMesh.faces,
  * OBJMesh.vertex_positions,
  * OBJMesh.vertex_normals,
  * OBJMesh.vertex_texcoords,
    PerspectiveCamera.near_clip,
    PerspectiveCamera.far_clip,
    PerspectiveCamera.focus_distance,
    PerspectiveCamera.shutter_open,
    PerspectiveCamera.shutter_open_time,
]
In [4]:
# 画像の配列とサイズを渡してmatplotlibで描画する関数
def plotBMP(data, size, title, figsize):
    channel = len(data) // (size[0]*size[1])
    print("size={0}, {1}, {2}".format(size[1], size[0], channel))
    bmp = Bitmap(np.array(data).reshape(size[1], size[0], channel))
    bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)

    # 表示
    plt.figure(figsize=figsize)
    plt.imshow(bmp_rgb)
    plt.title(title)
    plt.grid("off")
    plt.axis("off")
    plt.show()

# 左右に並べて比較
def tileBMP(data1, data2, size, figsize):
    channel = len(data1) // (size[0]*size[1])
    print("size={0}, {1}, {2}".format(size[1], size[0], channel))
    bmp1 = Bitmap(np.array(data1).reshape(size[1], size[0], channel))
    bmp_rgb1 = bmp1.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)
    bmp2 = Bitmap(np.array(data2).reshape(size[1], size[0], channel))
    bmp_rgb2 = bmp2.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)

    # 表示
    plt.figure(figsize=figsize)
    plt.subplot(1, 2, 1)
    plt.imshow(bmp_rgb1)
    plt.grid("off")
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(bmp_rgb2)
    plt.grid("off")
    plt.axis("off")
    plt.show()
    
In [5]:
# バックアップを作成する
param_res = params['myMaterial.reflectance.resolution']
param_ref = Float(params['myMaterial.reflectance.data'])
print("resolution=", param_res)
print("data=", param_ref)

# 確認
plotBMP(param_ref, param_res, "cow_texture.png", (5, 2.5))
resolution= [1024, 1024]
data= [1, 0.854992, 0.791298, 1, 0.854992, .. 3145718 skipped .., 0.854992, 0.791298, 1, 0.854992, 0.791298]
size=1024, 1024, 3
In [6]:
# 微分したいパラメータを除いて、すべてのパラメータを破棄する
params.keep(['myMaterial.reflectance.data'])

# 目的画像をレンダリングする(まだ導関数は使用していない)
image_ref = render(scene, spp=16)
crop_size = scene.sensors()[0].film().crop_size()

# 確認
plotBMP(image_ref, crop_size, "image_ref", (5,5))
size=512, 512, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [7]:
# 白で統一されたテクスチャマップに変更(初期画像)
params['myMaterial.reflectance.data'] = ek.full(Float, 1.0, len(param_ref))
params.update()

# 確認
image = render(scene, spp=16)
plotBMP(image, crop_size, "image", (5,5))
size=512, 512, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [8]:
# パラメータの最適化手法としてAdamを選択
opt = Adam(params, lr=.02)

# 学習
time_a = time.time()

iterations = 501

for it in range(iterations):
    # シーンの微分可能レンダリングを実行
    image = render(scene, optimizer=opt, unbiased=True, spp=1)
    
    # ロスを取る(レンダリング画像と目的画像のMSE)
    ob_val = ek.hsum(ek.sqr(image - image_ref)) / len(image)
    
    # 誤差を入力パラメータに逆伝搬する
    ek.backward(ob_val)
    
    # 勾配のステップを実行
    opt.step()
    
    # 環境光源マップについてGround Truthの値と比較した結果を出力する
    err_ref = ek.hsum(ek.sqr(param_ref - params['myMaterial.reflectance.data']))
    print('Iteration %03i: error=%g' % (it, err_ref[0]), end='\r')
    
    if it % 50 == 0:
        plotBMP(image, crop_size, 'Iteration %03i' % it, (5, 5))
    
time_b = time.time()

print()
print('%f ms per iteration' % (((time_b - time_a) * 1000) / iterations))
size=512, 512, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=249205
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=232652
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=225810
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=222205
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=220230
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=219348
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=219181
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=219415
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=220082
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=221030
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
45.454436 ms per iteration
In [9]:
# バックアップを作成する
result = params['myMaterial.reflectance.data']

# 確認
tileBMP(param_ref, result, param_res, (10, 5))
size=1024, 1024, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [ ]: