In [1]:
import enoki as ek
import mitsuba
mitsuba.set_variant('gpu_autodiff_rgb')

from mitsuba.core import Float, Thread, Bitmap, Struct           # add Bitmap, Struct
from mitsuba.core.xml import load_file
from mitsuba.python.util import traverse
from mitsuba.python.autodiff import render, write_bitmap, Adam
import time

# Plotting / utility imports added on top of the basic mitsuba setup.
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
# Keep figure resolution modest so the notebook stays lightweight.
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80

import os
import numpy as np
from PIL import Image
from io import BytesIO 
In [2]:
# Load the scene. NOTE(review): the original comment said "cow_mesh",
# but the file actually loaded here is head.xml.
scene = load_file('head.xml')
2020-05-14 01:50:33 INFO main [xml.cpp:1129] Loading XML file "head.xml" ..
2020-05-14 01:50:33 INFO main [xml.cpp:1130] Using variant "gpu_autodiff_rgb"
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/srgb_d65.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/path.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/independent.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/gaussian.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/hdrfilm.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/perspective.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/constant.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/uniform.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/bitmap.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/diffuse.so" ..
2020-05-14 01:50:33 INFO main [PluginManager] Loading plugin "plugins/obj.so" ..
2020-05-14 01:50:33 INFO main [Scene] Validating and building scene in OptiX.
In [3]:
# Enumerate the differentiable scene parameters.
# Entries marked "*" in the printed ParameterMap are differentiable.
params = traverse(scene)
print(params)
ParameterMap[
    PerspectiveCamera.near_clip,
    PerspectiveCamera.far_clip,
    PerspectiveCamera.focus_distance,
    PerspectiveCamera.shutter_open,
    PerspectiveCamera.shutter_open_time,
  * ConstantBackgroundEmitter.radiance.value,
  * my_image_000.data,
    my_image_000.resolution,
    my_image_000.transform,
  * material_000.reflectance.data,
    material_000.reflectance.resolution,
    material_000.reflectance.transform,
    OBJMesh.vertex_count,
    OBJMesh.face_count,
    OBJMesh.faces,
  * OBJMesh.vertex_positions,
  * OBJMesh.vertex_normals,
  * OBJMesh.vertex_texcoords,
]
In [4]:
# Display a flat pixel buffer as an image via matplotlib.
def plotBMP(data, size, title, figsize):
    """Show a flat float pixel buffer as a gamma-corrected RGB image.

    Parameters
    ----------
    data : flat sequence of floats, length size[0] * size[1] * channels
    size : (width, height) of the image
    title : figure title
    figsize : matplotlib figure size in inches
    """
    # Infer the channel count from the buffer length (e.g. 3 for RGB).
    channel = len(data) // (size[0]*size[1])
    print("size={0}, {1}, {2}".format(size[1], size[0], channel))
    bmp = Bitmap(np.array(data).reshape(size[1], size[0], channel))
    # Convert to sRGB (gamma-corrected) float RGB for display.
    bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)

    # Display. grid(False) disables the grid; the original grid("off")
    # passed a truthy string, which actually *enables* it.
    plt.figure(figsize=figsize)
    plt.imshow(bmp_rgb)
    plt.title(title)
    plt.grid(False)
    plt.axis("off")
    plt.show()

# Display two flat pixel buffers side by side for comparison.
def tileBMP(data1, data2, size, figsize):
    """Show two equally-sized flat pixel buffers next to each other.

    Parameters
    ----------
    data1, data2 : flat sequences of floats, length size[0] * size[1] * channels
    size : (width, height) shared by both images
    figsize : matplotlib figure size in inches
    """
    # Infer the channel count from the first buffer (both must match).
    channel = len(data1) // (size[0]*size[1])
    print("size={0}, {1}, {2}".format(size[1], size[0], channel))
    bmp1 = Bitmap(np.array(data1).reshape(size[1], size[0], channel))
    bmp_rgb1 = bmp1.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)
    bmp2 = Bitmap(np.array(data2).reshape(size[1], size[0], channel))
    bmp_rgb2 = bmp2.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)

    # Display. grid(False) disables the grid; the original grid("off")
    # passed a truthy string, which actually *enables* it.
    plt.figure(figsize=figsize)
    plt.subplot(1, 2, 1)
    plt.imshow(bmp_rgb1)
    plt.grid(False)
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(bmp_rgb2)
    plt.grid(False)
    plt.axis("off")
    plt.show()
    
In [5]:
# Back up the original texture before optimization overwrites it.
# Float(...) makes a copy; 'resolution' is (width, height) of the texture.
param_res = params['material_000.reflectance.resolution']
param_ref = Float(params['material_000.reflectance.data'])
print("resolution=", param_res)
print("data=", param_ref)

# Sanity check: display the backed-up texture.
plotBMP(param_ref, param_res, "head_texture.png", (5, 5))
resolution= [4096, 4096]
data= [0.637597, 0.371238, 0.283149, 0.637597, 0.371238, .. 50331638 skipped .., 0.371238, 0.283149, 0.637597, 0.371238, 0.283149]
size=4096, 4096, 3
In [6]:
# Discard every parameter except the one we want to differentiate.
params.keep(['material_000.reflectance.data'])

# Load the target (reference) image from disk instead of rendering it
# (derivatives are not used yet).
#image_ref = render(scene, spp=16)
with open('FaceRig.jpg', 'rb') as f:
    binary = f.read()
img = Image.open(BytesIO(binary))
image_ref = np.asarray(img)
print(image_ref.shape)

# Display the reference image.
plt.figure(figsize=(5,5))
plt.imshow(image_ref)
plt.title("reference")
plt.grid(False)
plt.axis("off")
plt.show()

# Flatten and normalize to [0, 1]. reshape(-1) handles any resolution
# instead of hard-coding 512*512*3.
image_ref = image_ref.reshape(-1)
image_ref = image_ref / 255.0
print(image_ref.shape)
(512, 512, 3)
(786432,)
In [7]:
# (Optional) reset the texture map to uniform white as the starting point;
# left commented out so optimization starts from the current texture.
#params['material_000.reflectance.data'] = ek.full(Float, 1.0, len(param_ref))
#params.update()

# Render once with the current texture to verify the setup.
image = render(scene, spp=4)
crop_size = scene.sensors()[0].film().crop_size()
plotBMP(image, crop_size, "image", (5,5))
size=512, 512, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [8]:
# Initial loss: mean squared error between the rendering and the reference.
# NOTE(review): recomputed inside the optimization loop below; this cell
# is only a sanity check of the starting error.
ob_val = ek.hsum(ek.sqr(image - image_ref)) / len(image)
In [9]:
# Optimize the remaining scene parameter with Adam.
opt = Adam(params, lr=.02)

# Optimization loop.
time_a = time.time()

iterations = 501

for it in range(iterations):
    # Perform a differentiable rendering of the scene.
    image = render(scene, optimizer=opt, unbiased=True, spp=8)
    
    # Loss: MSE between the rendered image and the reference image.
    ob_val = ek.hsum(ek.sqr(image - image_ref)) / len(image)
    
    # Back-propagate the error to the input parameter.
    ek.backward(ob_val)
    
    # Take one gradient step.
    opt.step()
    
    # Alternative metric: compare the texture against the ground truth.
#    err_ref = ek.hsum(ek.sqr(param_ref - params['material_000.reflectance.data']))
#    print('Iteration %03i: error=%g' % (it, err_ref[0]), end='\r')
    print('Iteration %03i: error=%g' % (it, ob_val[0]), end='\r')
    
    if it % 50 == 0:
        # Terminate the carriage-return progress line first; otherwise
        # plotBMP's own print overwrites it and the output is garbled
        # (e.g. "size=512, 512, 3rror=...").
        print()
        plotBMP(image, crop_size, 'Iteration %03i' % it, (5, 5))
    
time_b = time.time()

print()
print('%f ms per iteration' % (((time_b - time_a) * 1000) / iterations))
size=512, 512, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0307134
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0262247
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0237799
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0225756
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0219451
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0215988
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0213745
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0212193
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0211554
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
size=512, 512, 3rror=0.0210774
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
146.057484 ms per iteration
In [10]:
# Grab the optimized texture data.
result = params['material_000.reflectance.data']

# Compare the original (left) and optimized (right) textures side by side.
tileBMP(param_ref, result, param_res, (10, 5))
size=4096, 4096, 3
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [11]:
ret = np.asarray(result)
ret = ret.reshape(4096, 4096, 3)
ret *= 255.0
print(ret.shape)
pil_img = Image.fromarray(ret.astype(np.uint8))
print(pil_img.mode)
# RGB

pil_img.save('result.jpg')
(4096, 4096, 3)
RGB
In [ ]: