下手なdllよりもluajitの方が速いという話
平方根の高速化
平方根(sqrt関数)を高速に求めたかったので、fast sqrt algorithmというものを学びました。平方根としての計算精度を犠牲にして、演算数を少なくしたものです。実際にc++のstd::sqrtと比較してみると、確かに高速です。これをluaでも呼び出せば、luaスクリプトの実行速度が速くなるのではと思い、試してみました。 fast sqrt algorithmをluaでも
M_Blur_Moduleにあるtime関数を使用して処理速度を比較してみました。
まずは、純粋なluaコードです。
code:lua
<?
local d=require("M_Blur_Module")
local t1=d.time()
local s=0
for i=1,0xffffff do
s=s+math.sqrt(i)
end
debug_print(string.format("%f : %f",d.time()-t1,s))
?>
大体、40ぐらいの時間がかかりました。
次は、math.sqrtをfast sqrt algorithmに置き換えて処理速度を見てみます。
code:lua
local ffi=require("ffi")
local bit=require("bit")
math.sqrt = function(a)
local a_ = ffi.new("float1", a) local fa = ffi.cast("int*", a_)
local tmp = ffi.new("int1", 0x5f3700a0 - bit.rshift(fa0, 1)) local xk = ffi.cast("float*", tmp)
xk0 = xk0 * (1.5 - (hx * xk0 * xk0)) --xk0 = xk0 * (1.5 - (hx * xk0 * xk0)) end
大体、1200ぐらいの時間がかかりました。
デフォルトだと40で、精度を捨てて速度を優先しているfast sqrt algorithmだと1200です。fast sqrt algorithmの方が30倍遅いです。なにかおかしいです。
luaでのキャストが上手くいっていないのかと思い、c++でdllを書いてその関数を呼び出して処理時間を計測してみました。
code:M_FastSqrt_Module.cpp
extern "C" {
__declspec(dllexport) float sqrt_M(float a) {
float hx = 0.5 * a;
int tmp = 0x5f3700a0 - (std::bit_cast<int>(a) >> 1);
float xk = std::bit_cast<float>(tmp);
xk = xk * (1.5 - (hx * xk * xk));
//xk = xk * (1.5 - (hx * xk * xk));
return xk * a;
}
}
code:lua
<?
local d=require("M_Blur_Module")
local t1=d.time()
local ffi = require("ffi")
ffi.cdeffloat sqrt_M(float a);
local s = ffi.load("M_FastSqrt_Module")
math.sqrt=s.sqrt_M
local s=0
for i=1,0xffffff do
s=s+math.sqrt(i)
end
debug_print(string.format("%f : %f",d.time()-t1,s))
?>
それでも大体、80ぐらいの時間がかかりました。
SIMD命令と職人の御業
c++のコードよりもluaの方が高速に動作するというのはとても奇妙です。
ですが、よくよく考えてみると僕が使っていたのはluaではありません。僕が使っていたのはluajitです。luajitというのは、職人が1つ1つ手編みでアセンブラをかきかきしたりして製作された恐ろしい言語です。
実はfast sqrt algorithmは最速ではありません。最速はSIMD命令です。もしもluajitがSIMD命令を使用してmath.sqrtを実装しているのならば、c++もSIMD命令を使わないと処理速度で肩を並べることができないです。
c++でsimd命令を使って平方根を求めるプログラムを書き、標準std::sqrtなどと比較してみました。
code:main.cpp
class timecount {//速度計測用
public:
LARGE_INTEGER freq;
LARGE_INTEGER start, end;
timecount() {
QueryPerformanceFrequency(&freq);
QueryPerformanceCounter(&start);
}
~timecount() {
QueryPerformanceCounter(&end);
double time = static_cast<double>(end.QuadPart - start.QuadPart) * 1000.0 / freq.QuadPart;
std::cout << "time : " << time << std::endl;
}
};
float fsqrt(float a) {
float hx = 0.5 * a;
int tmp = 0x5F3700a0 - (std::bit_cast<int>(a) >> 1);
float xk = std::bit_cast<float>(tmp);
xk = xk * (1.5 - (hx * xk * xk));
//xk = xk * (1.5 - (hx * xk * xk));
return xk * a;
}
float sqrt_ss(float a) {
__m128 re = _mm_load_ss(&a);
re = _mm_sqrt_ss(re);
_mm_store_ss(&a, re);
return a;
}
float sqrt_ps(float a) {
__m128 re = _mm_load_ss(&a);
re = _mm_sqrt_ps(re);
_mm_store_ss(&a, re);
return a;
}
float fsqrt_ss(float a) {
__m128 re = _mm_load_ss(&a);
__m128 re2 = _mm_rsqrt_ss(re);
re = _mm_mul_ss(re, re2);
_mm_store_ss(&a, re);
return a;
}
int main() {
float s;
std::cout << std::sqrt(2.0) << std::endl;
std::cout << fsqrt(2.0) << std::endl;
std::cout << sqrt_ss(2.0) << std::endl;
std::cout << sqrt_ps(2.0) << std::endl;
std::cout << fsqrt_ss(2.0) << std::endl;
std::cout << std::endl;
s = 0;
{//std::sqrt
class timecount obj;
for (auto i = 0; i < 0x10000000; i++) s += std::sqrt(i);
}
std::cout << s << std::endl;
s = 0;
{//fsqrt
class timecount obj;
for (auto i = 0; i < 0x10000000; i++) s += fsqrt(i);
}
std::cout << s << std::endl;
s = 0;
{//sqrt_ss
class timecount obj;
for (auto i = 0; i < 0x10000000; i++) s += sqrt_ss(i);
}
std::cout << s << std::endl;
s = 0;
{//sqrt_ps
class timecount obj;
for (auto i = 0; i < 0x10000000; i++) s += sqrt_ps(i);
}
std::cout << s << std::endl;
s = 0;
{//fsqrt_ss
class timecount obj;
for (auto i = 0; i < 0x10000000; i++) s += fsqrt_ss(i);
}
std::cout << s << std::endl;
system("pause");
return 0;
}
結果は以下の表のようになりました。
table:結果
2の平方根 速度
std::sqrt 1.41421 1346.36
fsqrt 1.41396 836.426
sqrt_ss 1.41421 353.209
sqrt_ps 1.41421 348.431
fsqrt_ss 1.41382 350.802
SIMD命令にあまり詳しくなかったので、SIMD命令のコードは三種類書いて比較してみました。精度面と速度面から、sqrt_ssかsqrt_psかのどちらかの方法が最適であることが分かります。
sqrt_psのプログラムを用いてluajitのmath.sqrtと比較してみます。fast sqrt algorithmのときのdllを以下のdllに置き換えて速度を計りました。
code:M_FastSqrt_Module.cpp
extern "C" {
__declspec(dllexport) float sqrt_M(float a) {
__m128 re = _mm_load_ss(&a);
re = _mm_sqrt_ss(re);
_mm_store_ss(&a, re);
return a;
}
}
大体、55ぐらいの時間がかかりました。luajitのmath.sqrtだと40だったので、dllで呼び出している分のオーバーヘッドが差分の15なのだと思います。
dllよりもluajitが速い状況
ここで、dllよりもluajitで処理をした方がいい状況というものがいろいろ見えてきます。
例えば、c++のstd::sqrtを0x10000000回使用した場合の処理時間は1346.36でした。
luajitで同じ回数計算してみると、処理時間は520ぐらいです。
code:lua
<?
local d=require("M_Blur_Module")
local t1=d.time()
local s=0
for i=0,0xfffffff do
s=s+math.sqrt(i)
end
debug_print(string.format("%f : %f",d.time()-t1,s))
?>
このことから、最適化をしていないdllよりもluajitの方が速いということが分かると思います。速度面を重要視してdllを書くのであれば、SIMD命令やマルチスレッドなど何かしらの方法で最適化しないと意味がないということです。
むすび
本当はAviUtlのmath.sqrtをfast sqrt algorithmで高速化したかったのですが、できなかったので変な終着点で不時着しているような状態になってしまいました。計画性、大事です。
by metaphysical bard