diff --git a/src/lua/ai.c b/src/lua/ai.c index 63702785ca93..5c20087d1c93 100644 --- a/src/lua/ai.c +++ b/src/lua/ai.c @@ -414,6 +414,64 @@ static int _tensor_scale_add(lua_State *L) return 1; } +// tensor:scale_add_planes({scales} [, {offsets}]) → self +// in-place per-plane: t[k,...] = t[k,...] * scales[k+1] + offsets[k+1]. +// 1-indexed tables, one entry per channel plane (dim 1). offsets +// defaults to 0. Requires ndim >= 2 and shape[0] == 1. +static int _tensor_scale_add_planes(lua_State *L) +{ + dt_lua_ai_tensor_t *t + = luaL_checkudata(L, 1, "dt_lua_ai_tensor_t"); + if(!t->data) return luaL_error(L, "tensor has been freed"); + luaL_checktype(L, 2, LUA_TTABLE); + const gboolean has_offsets = !lua_isnoneornil(L, 3); + if(has_offsets) luaL_checktype(L, 3, LUA_TTABLE); + if(t->ndim < 2 || t->shape[0] != 1) + return luaL_error(L, "scale_add_planes requires [1,C,...] tensor"); + const size_t C = (size_t)t->shape[1]; + + float *scales = g_new0(float, C); + float *offsets = g_new0(float, C); + for(size_t k = 0; k < C; k++) + { + lua_rawgeti(L, 2, (int)(k + 1)); + if(!lua_isnumber(L, -1)) + { + g_free(scales); g_free(offsets); + return luaL_error(L, "scales[%d] is not a number", (int)(k + 1)); + } + scales[k] = (float)lua_tonumber(L, -1); + lua_pop(L, 1); + if(has_offsets) + { + lua_rawgeti(L, 3, (int)(k + 1)); + if(!lua_isnumber(L, -1)) + { + g_free(scales); g_free(offsets); + return luaL_error(L, "offsets[%d] is not a number", (int)(k + 1)); + } + offsets[k] = (float)lua_tonumber(L, -1); + lua_pop(L, 1); + } + } + + size_t plane = 1; + for(int d = 2; d < t->ndim; d++) plane *= (size_t)t->shape[d]; + + for(size_t k = 0; k < C; k++) + { + float *p = t->data + k * plane; + const float s = scales[k]; + const float o = offsets[k]; + for(size_t i = 0; i < plane; i++) p[i] = p[i] * s + o; + } + + g_free(scales); + g_free(offsets); + lua_settop(L, 1); + return 1; +} + // tensor:sum() → float // sum of all elements (double accumulation, returned as Lua number) static int _tensor_sum(lua_State *L) @@ -1672,6 +1730,8 @@ int dt_lua_init_ai(lua_State *L) lua_setfield(L, -2, "fill"); lua_pushcfunction(L, _tensor_scale_add); lua_setfield(L, -2, "scale_add"); + lua_pushcfunction(L, _tensor_scale_add_planes); + lua_setfield(L, -2, "scale_add_planes"); lua_pushcfunction(L, _tensor_sum); lua_setfield(L, -2, "sum"); lua_pushcfunction(L, _tensor_mean);