Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 54 additions & 18 deletions linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
#include <string.h>
#include <stdarg.h>
#include <locale.h>

#if __unix__
# include <unistd.h> // For _POSIX_VERSION
#endif

#include "linear.h"
#include "tron.h"

typedef signed char schar;
template <class T> static inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
#ifndef min
Expand Down Expand Up @@ -2713,6 +2719,49 @@ static const char *solver_type_table[]=
"L2R_L2LOSS_SVR", "L2R_L2LOSS_SVR_DUAL", "L2R_L1LOSS_SVR_DUAL", NULL
};

#if _POSIX_VERSION >= 200809L

// If possible, use the thread-safe uselocale() function
typedef locale_t locale_handle;

static locale_handle set_c_locale()
{
locale_handle c_locale = newlocale(LC_ALL_MASK, "C", 0);
locale_handle old_locale = uselocale(c_locale);
return old_locale;
}

static void restore_locale(locale_handle locale)
{
locale_handle c_locale = uselocale(locale);
if (c_locale && c_locale != LC_GLOBAL_LOCALE) {
freelocale(c_locale);
}
}

#else

// But fall back to setlocale() if uselocale() is not available
typedef char *locale_handle;

static locale_handle set_c_locale()
{
locale_handle old_locale = setlocale(LC_ALL, NULL);
if (old_locale) {
old_locale = strdup(old_locale);
}
setlocale(LC_ALL, "C");
return old_locale;
}

static void restore_locale(locale_handle locale)
{
setlocale(LC_ALL, locale);
free(locale);
}

#endif

int save_model(const char *model_file_name, const struct model *model_)
{
int i;
Expand All @@ -2728,12 +2777,7 @@ int save_model(const char *model_file_name, const struct model *model_)
FILE *fp = fopen(model_file_name,"w");
if(fp==NULL) return -1;

char *old_locale = setlocale(LC_ALL, NULL);
if (old_locale)
{
old_locale = strdup(old_locale);
}
setlocale(LC_ALL, "C");
locale_handle old_locale = set_c_locale();

int nr_w;
if(model_->nr_class==2 && model_->param.solver_type != MCSVM_CS)
Expand Down Expand Up @@ -2765,8 +2809,7 @@ int save_model(const char *model_file_name, const struct model *model_)
fprintf(fp, "\n");
}

setlocale(LC_ALL, old_locale);
free(old_locale);
restore_locale(old_locale);

if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
else return 0;
Expand All @@ -2790,10 +2833,9 @@ int save_model(const char *model_file_name, const struct model *model_)
// EXIT_LOAD_MODEL should NOT end with a semicolon.
#define EXIT_LOAD_MODEL()\
{\
setlocale(LC_ALL, old_locale);\
restore_locale(old_locale);\
free(model_->label);\
free(model_);\
free(old_locale);\
return NULL;\
}
struct model *load_model(const char *model_file_name)
Expand All @@ -2811,12 +2853,7 @@ struct model *load_model(const char *model_file_name)

model_->label = NULL;

char *old_locale = setlocale(LC_ALL, NULL);
if (old_locale)
{
old_locale = strdup(old_locale);
}
setlocale(LC_ALL, "C");
locale_handle old_locale = set_c_locale();

char cmd[81];
while(1)
Expand Down Expand Up @@ -2898,8 +2935,7 @@ struct model *load_model(const char *model_file_name)
}
}

setlocale(LC_ALL, old_locale);
free(old_locale);
restore_locale(old_locale);

if (ferror(fp) != 0 || fclose(fp) != 0) return NULL;

Expand Down