|
83 | 83 | * @note This file is included from py_nif.c (single compilation unit) |
84 | 84 | */ |
85 | 85 |
|
| 86 | +/* ============================================================================ |
| 87 | + * Callback Name Registry |
| 88 | + * |
| 89 | + * Maintains a C-side registry of registered callback function names. |
| 90 | + * This allows erlang_module_getattr to only return ErlangFunction wrappers |
| 91 | + * for actually registered functions, preventing introspection issues with |
| 92 | + * libraries like torch that probe module attributes. |
| 93 | + * ============================================================================ */ |
| 94 | + |
| 95 | +/** |
| 96 | + * @def CALLBACK_REGISTRY_BUCKETS |
| 97 | + * @brief Number of hash buckets for the callback registry |
| 98 | + */ |
| 99 | +#define CALLBACK_REGISTRY_BUCKETS 64 |
| 100 | + |
| 101 | +/** |
| 102 | + * @struct callback_name_entry_t |
| 103 | + * @brief Entry in the callback name registry hash table |
| 104 | + */ |
| 105 | +typedef struct callback_name_entry { |
| 106 | + char *name; /**< Callback name (owned) */ |
| 107 | + size_t name_len; /**< Length of name */ |
| 108 | + struct callback_name_entry *next; /**< Next entry in bucket chain */ |
| 109 | +} callback_name_entry_t; |
| 110 | + |
| 111 | +/** @brief Hash table buckets for callback registry */ |
| 112 | +static callback_name_entry_t *g_callback_registry[CALLBACK_REGISTRY_BUCKETS] = {NULL}; |
| 113 | + |
| 114 | +/** @brief Mutex protecting the callback registry */ |
| 115 | +static pthread_mutex_t g_callback_registry_mutex = PTHREAD_MUTEX_INITIALIZER; |
| 116 | + |
| 117 | +/** |
| 118 | + * @brief Simple hash function for callback names |
| 119 | + */ |
| 120 | +static unsigned int callback_name_hash(const char *name, size_t len) { |
| 121 | + unsigned int hash = 5381; |
| 122 | + for (size_t i = 0; i < len; i++) { |
| 123 | + hash = ((hash << 5) + hash) + (unsigned char)name[i]; |
| 124 | + } |
| 125 | + return hash % CALLBACK_REGISTRY_BUCKETS; |
| 126 | +} |
| 127 | + |
| 128 | +/** |
| 129 | + * @brief Check if a callback name is registered |
| 130 | + * |
| 131 | + * Thread-safe lookup in the callback registry. |
| 132 | + * |
| 133 | + * @param name Callback name to check |
| 134 | + * @param len Length of name |
| 135 | + * @return true if registered, false otherwise |
| 136 | + */ |
| 137 | +static bool is_callback_registered(const char *name, size_t len) { |
| 138 | + unsigned int bucket = callback_name_hash(name, len); |
| 139 | + bool found = false; |
| 140 | + |
| 141 | + pthread_mutex_lock(&g_callback_registry_mutex); |
| 142 | + |
| 143 | + callback_name_entry_t *entry = g_callback_registry[bucket]; |
| 144 | + while (entry != NULL) { |
| 145 | + if (entry->name_len == len && memcmp(entry->name, name, len) == 0) { |
| 146 | + found = true; |
| 147 | + break; |
| 148 | + } |
| 149 | + entry = entry->next; |
| 150 | + } |
| 151 | + |
| 152 | + pthread_mutex_unlock(&g_callback_registry_mutex); |
| 153 | + return found; |
| 154 | +} |
| 155 | + |
| 156 | +/** |
| 157 | + * @brief Register a callback name |
| 158 | + * |
| 159 | + * Thread-safe addition to the callback registry. |
| 160 | + * |
| 161 | + * @param name Callback name to register |
| 162 | + * @param len Length of name |
| 163 | + * @return 0 on success, -1 on failure |
| 164 | + */ |
| 165 | +static int register_callback_name(const char *name, size_t len) { |
| 166 | + /* Check if already registered */ |
| 167 | + if (is_callback_registered(name, len)) { |
| 168 | + return 0; /* Already registered, success */ |
| 169 | + } |
| 170 | + |
| 171 | + /* Allocate new entry */ |
| 172 | + callback_name_entry_t *entry = enif_alloc(sizeof(callback_name_entry_t)); |
| 173 | + if (entry == NULL) { |
| 174 | + return -1; |
| 175 | + } |
| 176 | + |
| 177 | + entry->name = enif_alloc(len + 1); |
| 178 | + if (entry->name == NULL) { |
| 179 | + enif_free(entry); |
| 180 | + return -1; |
| 181 | + } |
| 182 | + |
| 183 | + memcpy(entry->name, name, len); |
| 184 | + entry->name[len] = '\0'; |
| 185 | + entry->name_len = len; |
| 186 | + |
| 187 | + unsigned int bucket = callback_name_hash(name, len); |
| 188 | + |
| 189 | + pthread_mutex_lock(&g_callback_registry_mutex); |
| 190 | + |
| 191 | + entry->next = g_callback_registry[bucket]; |
| 192 | + g_callback_registry[bucket] = entry; |
| 193 | + |
| 194 | + pthread_mutex_unlock(&g_callback_registry_mutex); |
| 195 | + |
| 196 | + return 0; |
| 197 | +} |
| 198 | + |
| 199 | +/** |
| 200 | + * @brief Unregister a callback name |
| 201 | + * |
| 202 | + * Thread-safe removal from the callback registry. |
| 203 | + * |
| 204 | + * @param name Callback name to unregister |
| 205 | + * @param len Length of name |
| 206 | + */ |
| 207 | +static void unregister_callback_name(const char *name, size_t len) { |
| 208 | + unsigned int bucket = callback_name_hash(name, len); |
| 209 | + |
| 210 | + pthread_mutex_lock(&g_callback_registry_mutex); |
| 211 | + |
| 212 | + callback_name_entry_t **pp = &g_callback_registry[bucket]; |
| 213 | + while (*pp != NULL) { |
| 214 | + callback_name_entry_t *entry = *pp; |
| 215 | + if (entry->name_len == len && memcmp(entry->name, name, len) == 0) { |
| 216 | + *pp = entry->next; |
| 217 | + enif_free(entry->name); |
| 218 | + enif_free(entry); |
| 219 | + break; |
| 220 | + } |
| 221 | + pp = &entry->next; |
| 222 | + } |
| 223 | + |
| 224 | + pthread_mutex_unlock(&g_callback_registry_mutex); |
| 225 | +} |
| 226 | + |
| 227 | +/** |
| 228 | + * @brief Clean up the callback registry |
| 229 | + * |
| 230 | + * Frees all entries. Called during NIF unload. |
| 231 | + */ |
| 232 | +static void cleanup_callback_registry(void) { |
| 233 | + pthread_mutex_lock(&g_callback_registry_mutex); |
| 234 | + |
| 235 | + for (int i = 0; i < CALLBACK_REGISTRY_BUCKETS; i++) { |
| 236 | + callback_name_entry_t *entry = g_callback_registry[i]; |
| 237 | + while (entry != NULL) { |
| 238 | + callback_name_entry_t *next = entry->next; |
| 239 | + enif_free(entry->name); |
| 240 | + enif_free(entry); |
| 241 | + entry = next; |
| 242 | + } |
| 243 | + g_callback_registry[i] = NULL; |
| 244 | + } |
| 245 | + |
| 246 | + pthread_mutex_unlock(&g_callback_registry_mutex); |
| 247 | +} |
| 248 | + |
86 | 249 | /* ============================================================================ |
87 | 250 | * Suspended state management |
88 | 251 | * ============================================================================ */ |
@@ -1061,10 +1224,29 @@ static PyObject *ErlangFunction_call(ErlangFunctionObject *self, PyObject *args, |
1061 | 1224 |
|
1062 | 1225 | /** |
1063 | 1226 | * Module __getattr__ - enables "from erlang import func_name" and "erlang.func_name()" |
| 1227 | + * |
| 1228 | + * Only returns ErlangFunction wrapper for REGISTERED callback names. |
| 1229 | + * This prevents torch and other libraries that introspect module attributes |
| 1230 | + * from getting callable objects for arbitrary attribute names. |
1064 | 1231 | */ |
1065 | 1232 | static PyObject *erlang_module_getattr(PyObject *module, PyObject *name) { |
1066 | 1233 | (void)module; /* Unused */ |
1067 | | - /* Return an ErlangFunction wrapper for any attribute access */ |
| 1234 | + |
| 1235 | + /* Get the name as a C string */ |
| 1236 | + const char *name_str = PyUnicode_AsUTF8(name); |
| 1237 | + if (name_str == NULL) { |
| 1238 | + return NULL; /* Exception already set */ |
| 1239 | + } |
| 1240 | + size_t name_len = strlen(name_str); |
| 1241 | + |
| 1242 | + /* Check if this callback is registered */ |
| 1243 | + if (!is_callback_registered(name_str, name_len)) { |
| 1244 | + PyErr_Format(PyExc_AttributeError, |
| 1245 | + "module 'erlang' has no attribute '%s'", name_str); |
| 1246 | + return NULL; |
| 1247 | + } |
| 1248 | + |
| 1249 | + /* Return an ErlangFunction wrapper for registered callbacks */ |
1068 | 1250 | return ErlangFunction_New(name); |
1069 | 1251 | } |
1070 | 1252 |
|
@@ -1716,3 +1898,72 @@ static ERL_NIF_TERM nif_resume_callback_dirty(ErlNifEnv *env, int argc, const ER |
1716 | 1898 |
|
1717 | 1899 | return result; |
1718 | 1900 | } |
| 1901 | + |
| 1902 | +/* ============================================================================ |
| 1903 | + * NIF functions for callback name registration |
| 1904 | + * ============================================================================ */ |
| 1905 | + |
| 1906 | +/** |
| 1907 | + * @brief NIF to register a callback name in the C-side registry |
| 1908 | + * |
| 1909 | + * This allows the erlang module's __getattr__ to return ErlangFunction |
| 1910 | + * wrappers only for registered callbacks, preventing introspection issues. |
| 1911 | + * |
| 1912 | + * Args: Name (binary or atom) |
| 1913 | + * Returns: ok | {error, Reason} |
| 1914 | + */ |
| 1915 | +static ERL_NIF_TERM nif_register_callback_name(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { |
| 1916 | + (void)argc; |
| 1917 | + |
| 1918 | + ErlNifBinary name_bin; |
| 1919 | + char atom_buf[256]; |
| 1920 | + |
| 1921 | + const char *name; |
| 1922 | + size_t name_len; |
| 1923 | + |
| 1924 | + if (enif_inspect_binary(env, argv[0], &name_bin)) { |
| 1925 | + name = (const char *)name_bin.data; |
| 1926 | + name_len = name_bin.size; |
| 1927 | + } else if (enif_get_atom(env, argv[0], atom_buf, sizeof(atom_buf), ERL_NIF_LATIN1)) { |
| 1928 | + name = atom_buf; |
| 1929 | + name_len = strlen(atom_buf); |
| 1930 | + } else { |
| 1931 | + return make_error(env, "invalid_name"); |
| 1932 | + } |
| 1933 | + |
| 1934 | + if (register_callback_name(name, name_len) < 0) { |
| 1935 | + return make_error(env, "registration_failed"); |
| 1936 | + } |
| 1937 | + |
| 1938 | + return ATOM_OK; |
| 1939 | +} |
| 1940 | + |
| 1941 | +/** |
| 1942 | + * @brief NIF to unregister a callback name from the C-side registry |
| 1943 | + * |
| 1944 | + * Args: Name (binary or atom) |
| 1945 | + * Returns: ok |
| 1946 | + */ |
| 1947 | +static ERL_NIF_TERM nif_unregister_callback_name(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { |
| 1948 | + (void)argc; |
| 1949 | + |
| 1950 | + ErlNifBinary name_bin; |
| 1951 | + char atom_buf[256]; |
| 1952 | + |
| 1953 | + const char *name; |
| 1954 | + size_t name_len; |
| 1955 | + |
| 1956 | + if (enif_inspect_binary(env, argv[0], &name_bin)) { |
| 1957 | + name = (const char *)name_bin.data; |
| 1958 | + name_len = name_bin.size; |
| 1959 | + } else if (enif_get_atom(env, argv[0], atom_buf, sizeof(atom_buf), ERL_NIF_LATIN1)) { |
| 1960 | + name = atom_buf; |
| 1961 | + name_len = strlen(atom_buf); |
| 1962 | + } else { |
| 1963 | + return make_error(env, "invalid_name"); |
| 1964 | + } |
| 1965 | + |
| 1966 | + unregister_callback_name(name, name_len); |
| 1967 | + |
| 1968 | + return ATOM_OK; |
| 1969 | +} |
0 commit comments