import enoki as ek
import mitsuba
mitsuba.set_variant('gpu_autodiff_rgb')
from mitsuba.core import Float, Thread, Bitmap, Struct # add Bitmap, Struct
from mitsuba.core.xml import load_file
from mitsuba.python.util import traverse
from mitsuba.python.autodiff import render, write_bitmap, Adam
import time
# add
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
import os
import numpy as np
from PIL import Image
from io import BytesIO
# Load the scene (cow_mesh) from its XML description.
scene = load_file('result_test2.xml')
# Collect the scene parameters exposed by traverse(); the differentiable
# ones are selected from this container further down.
params = traverse(scene)
# Print the parameter container so the available keys can be inspected.
print(params)
# Draw a flat pixel buffer with matplotlib, given the buffer and its size.
def plotBMP(data, size, title, figsize):
    """Show a flat pixel buffer as an sRGB image in a new matplotlib figure.

    Args:
        data: Flat sequence of pixel values. Its length divided by
            ``size[0] * size[1]`` is taken as the channel count.
        size: Indexable pair; the buffer is reshaped to
            ``(size[1], size[0], channels)``, so presumably
            ``(width, height)`` -- confirm against the caller.
        title: Text placed above the image.
        figsize: Figure size forwarded to ``plt.figure``.

    Returns:
        None. The image is displayed and the inferred dimensions are printed.
    """
    channel = len(data) // (size[0] * size[1])
    print("size={0}, {1}, {2}".format(size[1], size[0], channel))
    bmp = Bitmap(np.array(data).reshape(size[1], size[0], channel))
    # Convert to 32-bit float RGB with sRGB gamma applied for display.
    bmp_rgb = bmp.convert(Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True)
    # Display
    plt.figure(figsize=figsize)
    plt.imshow(bmp_rgb)
    plt.title(title)
    # Pass a real boolean: the string "off" is truthy and is not interpreted
    # as "disable" by current matplotlib releases.
    plt.grid(False)
    plt.axis("off")
    plt.show()
# Show two pixel buffers side by side for comparison.
def tileBMP(data1, data2, size, figsize):
    """Show two flat pixel buffers next to each other in one figure.

    Both buffers are interpreted with the same geometry; the channel count
    is derived from ``data1`` only and applied to ``data2`` as well.

    Args:
        data1: Flat pixel buffer drawn in the left panel.
        data2: Flat pixel buffer drawn in the right panel.
        size: Indexable pair; each buffer is reshaped to
            ``(size[1], size[0], channels)``, so presumably
            ``(width, height)`` -- confirm against the caller.
        figsize: Figure size forwarded to ``plt.figure``.

    Returns:
        None. The figure is displayed and the inferred dimensions are printed.
    """
    channel = len(data1) // (size[0] * size[1])
    print("size={0}, {1}, {2}".format(size[1], size[0], channel))
    # Convert both buffers before opening the figure so that a conversion
    # failure does not leave a half-drawn figure behind.
    panels = [
        Bitmap(np.array(data).reshape(size[1], size[0], channel)).convert(
            Bitmap.PixelFormat.RGB, Struct.Type.Float32, srgb_gamma=True
        )
        for data in (data1, data2)
    ]
    # Display
    plt.figure(figsize=figsize)
    for column, bmp_rgb in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        plt.imshow(bmp_rgb)
        # Pass a real boolean: the string "off" is truthy and is not
        # interpreted as "disable" by current matplotlib releases.
        plt.grid(False)
        plt.axis("off")
    plt.show()
# Keep copies of the original parameter values before they are overwritten.
# Float(...) builds a new array from the current texture data.
param_res = params['material_000.reflectance.resolution']
param_ref = Float(params['material_000.reflectance.data'])
print("resolution=", param_res)
print("data=", param_ref)
# Visual check of the original surface texture.
plotBMP(param_ref, param_res, "uv_texture", (5, 2.5))
# Same backup for the environment map.
env_res = params['my_envmap.resolution']
env_ref = Float(params['my_envmap.data'])
print("resolution=", env_res)
print("data=", env_ref)
# Visual check of the original environment map.
# NOTE(review): the title "uv_texture" is reused here for the environment
# map plot -- looks like a copy-paste leftover; confirm the intended label.
plotBMP(env_ref, env_res, "uv_texture", (5, 2.5))
# Discard every parameter except the ones that should be differentiated.
params.keep(['material_000.reflectance.data','my_envmap.data'])
# Alternative (disabled): optimise the surface texture only.
#params.keep(['material_000.reflectance.data'])
# Load the target image that the optimisation should reproduce.
with open('test.png', 'rb') as f:
    binary = f.read()
# Decode from memory so no file handle stays open, and force three 8-bit
# channels: an RGBA, palette or greyscale PNG would otherwise not fit the
# 512*512*3 reshape below.
img = Image.open(BytesIO(binary)).convert('RGB')
image_ref = np.asarray(img)
print(image_ref.shape)
# Display
plt.figure(figsize=(5,5))
plt.imshow(image_ref)
plt.title("reference")
# Pass a real boolean: the string "off" is truthy and is not interpreted as
# "disable" by current matplotlib releases.
plt.grid(False)
plt.axis("off")
plt.show()
# Flatten to one dimension. The hard-coded size doubles as a check that the
# reference is 512x512; reshape raises ValueError for any other size.
image_ref = image_ref.reshape(512*512*3)
# Scale the 8-bit values to [0, 1].
# NOTE(review): no gamma/linearisation is applied to the PNG values here --
# confirm this matches the colour space of the rendered image.
image_ref = image_ref / 255.0
print(image_ref.shape)
# Replace the surface texture with a uniform white map (initial guess) ...
params['material_000.reflectance.data'] = ek.full(Float, 1.0, len(param_ref))
# ... and the environment map with a constant value of 0.1.
params['my_envmap.data'] = ek.full(Float, 0.1, len(env_ref))
# Push the modified values back into the scene before rendering.
params.update()
# Visual check: render the scene with the initial guess.
image = render(scene, spp=16)
# Film crop size of the first sensor, passed to plotBMP as the image size.
crop_size = scene.sensors()[0].film().crop_size()
plotBMP(image, crop_size, "image", (5,5))
# Mean squared difference between the initial render and the reference.
# The value is computed but not printed in this block.
ob_val = ek.hsum(ek.sqr(image - image_ref)) / len(image)
# Use Adam as the optimiser for the kept parameters.
opt = Adam(params, lr=.001)
# Training
time_a = time.time()
iterations = 1001
for it in range(iterations):
    # Run a differentiable rendering of the scene.
    image = render(scene, optimizer=opt, unbiased=True, spp=8)
    # Loss: mean squared error between the rendered and the target image.
    ob_val = ek.hsum(ek.sqr(image - image_ref)) / len(image)
    # Back-propagate the error to the input parameters.
    ek.backward(ob_val)
    # Take a gradient step.
    opt.step()
    # Alternative metric (disabled): squared error of the optimised surface
    # texture against its backed-up ground-truth values.
    # err_ref = ek.hsum(ek.sqr(param_ref - params['material_000.reflectance.data']))
    # print('Iteration %03i: error=%g' % (it, err_ref[0]), end='\r')
    # '\r' keeps the progress report on a single, continuously updated line.
    print('Iteration %03i: error=%g' % (it, ob_val[0]), end='\r')
    # Show the current render every 50 iterations.
    if it % 50 == 0:
        plotBMP(image, crop_size, 'Iteration %03i' % it, (5, 5))
time_b = time.time()
# Terminate the '\r' progress line.
print()
# The average includes the time spent in the periodic plotBMP calls above.
print('%f ms per iteration' % (((time_b - time_a) * 1000) / iterations))
# Fetch the optimised parameters and compare them with the backed-up
# originals (original on the left, optimised on the right).
result = params['material_000.reflectance.data']
tileBMP(param_ref, result, param_res, (10, 5))
result_env = params['my_envmap.data']
tileBMP(env_ref, result_env, env_res, (10, 5))

# Export the optimised surface texture as an 8-bit JPEG. The shape comes
# from the stored texture resolution rather than a hard-coded 500x500, using
# the same (size[1], size[0], channels) layout as plotBMP/tileBMP.
ret = np.asarray(result)
ret = ret.reshape(param_res[1], param_res[0], 3)
# Scale out of place so the array returned by np.asarray is never modified.
ret = np.clip(ret * 255.0, 0.0, 255.0)
print(ret.shape)
# NOTE(review): values are written without the sRGB gamma that
# plotBMP/tileBMP apply for display -- confirm this is intended.
pil_img = Image.fromarray(ret.astype(np.uint8))
print(pil_img.mode)
# expected output: RGB
pil_img.save('texture.jpg')

# Export the last rendered image the same way, sized by the film crop size
# rather than a hard-coded 512x512.
ret = np.asarray(image)
ret = ret.reshape(crop_size[1], crop_size[0], 3)
ret = np.clip(ret * 255.0, 0.0, 255.0)
print(ret.shape)
pil_img = Image.fromarray(ret.astype(np.uint8))
print(pil_img.mode)
# expected output: RGB
pil_img.save('result.jpg')