1515import struct
1616from dateutil .relativedelta import relativedelta
1717from pandas .core .base import StringMixin
18+ from pandas .core .categorical import Categorical
1819from pandas .core .frame import DataFrame
1920from pandas .core .series import Series
20- from pandas .core .categorical import Categorical
2121import datetime
2222from pandas import compat , to_timedelta , to_datetime , isnull , DatetimeIndex
2323from pandas .compat import lrange , lmap , lzip , text_type , string_types , range , \
24- zip
24+ zip , BytesIO
2525import pandas .core .common as com
2626from pandas .io .common import get_filepath_or_buffer
2727from pandas .lib import max_len_string_array , infer_dtype
@@ -336,6 +336,15 @@ class PossiblePrecisionLoss(Warning):
336336conversion range. This may result in a loss of precision in the saved data.
337337"""
338338
339+ class ValueLabelTypeMismatch (Warning ):
340+ pass
341+
342+ value_label_mismatch_doc = """
343+ Stata value labels (pandas categories) must be strings. Column {0} contains
344+ non-string labels which will be converted to strings. Please check that the
345+ Stata data file created has not lost information due to duplicate labels.
346+ """
347+
339348
340349class InvalidColumnName (Warning ):
341350 pass
@@ -425,6 +434,131 @@ def _cast_to_stata_types(data):
425434 return data
426435
427436
437+ class StataValueLabel (object ):
438+ """
439+ Parse a categorical column and prepare formatted output
440+
441+ Parameters
442+ -----------
443+ value : int8, int16, int32, float32 or float64
444+ The Stata missing value code
445+
446+ Attributes
447+ ----------
448+ string : string
449+ String representation of the Stata missing value
450+ value : int8, int16, int32, float32 or float64
451+ The original encoded missing value
452+
453+ Methods
454+ -------
455+ generate_value_label
456+
457+ """
458+
459+ def __init__ (self , catarray ):
460+
461+ self .labname = catarray .name
462+
463+ categories = catarray .cat .categories
464+ self .value_labels = list (zip (np .arange (len (categories )), categories ))
465+ self .value_labels .sort (key = lambda x : x [0 ])
466+ self .text_len = np .int32 (0 )
467+ self .off = []
468+ self .val = []
469+ self .txt = []
470+ self .n = 0
471+
472+ # Compute lengths and setup lists of offsets and labels
473+ for vl in self .value_labels :
474+ category = vl [1 ]
475+ if not isinstance (category , string_types ):
476+ category = str (category )
477+ import warnings
478+ warnings .warn (value_label_mismatch_doc .format (catarray .name ),
479+ ValueLabelTypeMismatch )
480+
481+ self .off .append (self .text_len )
482+ self .text_len += len (category ) + 1 # +1 for the padding
483+ self .val .append (vl [0 ])
484+ self .txt .append (category )
485+ self .n += 1
486+
487+ if self .text_len > 32000 :
488+ raise ValueError ('Stata value labels for a single variable must '
489+ 'have a combined length less than 32,000 '
490+ 'characters.' )
491+
492+ # Ensure int32
493+ self .off = np .array (self .off , dtype = np .int32 )
494+ self .val = np .array (self .val , dtype = np .int32 )
495+
496+ # Total length
497+ self .len = 4 + 4 + 4 * self .n + 4 * self .n + self .text_len
498+
499+ def _encode (self , s ):
500+ """
501+ Python 3 compatability shim
502+ """
503+ if compat .PY3 :
504+ return s .encode (self ._encoding )
505+ else :
506+ return s
507+
508+ def generate_value_label (self , byteorder , encoding ):
509+ """
510+ Parameters
511+ ----------
512+ byteorder : str
513+ Byte order of the output
514+ encoding : str
515+ File encoding
516+
517+ Returns
518+ -------
519+ value_label : bytes
520+ Bytes containing the formatted value label
521+ """
522+
523+ self ._encoding = encoding
524+ bio = BytesIO ()
525+ null_string = '\x00 '
526+ null_byte = b'\x00 '
527+
528+ # len
529+ bio .write (struct .pack (byteorder + 'i' , self .len ))
530+
531+ # labname
532+ labname = self ._encode (_pad_bytes (self .labname [:32 ], 33 ))
533+ bio .write (labname )
534+
535+ # padding - 3 bytes
536+ for i in range (3 ):
537+ bio .write (struct .pack ('c' , null_byte ))
538+
539+ # value_label_table
540+ # n - int32
541+ bio .write (struct .pack (byteorder + 'i' , self .n ))
542+
543+ # textlen - int32
544+ bio .write (struct .pack (byteorder + 'i' , self .text_len ))
545+
546+ # off - int32 array (n elements)
547+ for offset in self .off :
548+ bio .write (struct .pack (byteorder + 'i' , offset ))
549+
550+ # val - int32 array (n elements)
551+ for value in self .val :
552+ bio .write (struct .pack (byteorder + 'i' , value ))
553+
554+ # txt - Text labels, null terminated
555+ for text in self .txt :
556+ bio .write (self ._encode (text + null_string ))
557+
558+ bio .seek (0 )
559+ return bio .read ()
560+
561+
428562class StataMissingValue (StringMixin ):
429563 """
430564 An observation's missing value.
@@ -477,25 +611,31 @@ class StataMissingValue(StringMixin):
477611 for i in range (1 , 27 ):
478612 MISSING_VALUES [i + b ] = '.' + chr (96 + i )
479613
480- base = b'\x00 \x00 \x00 \x7f '
614+ float32_base = b'\x00 \x00 \x00 \x7f '
481615 increment = struct .unpack ('<i' , b'\x00 \x08 \x00 \x00 ' )[0 ]
482616 for i in range (27 ):
483- value = struct .unpack ('<f' , base )[0 ]
617+ value = struct .unpack ('<f' , float32_base )[0 ]
484618 MISSING_VALUES [value ] = '.'
485619 if i > 0 :
486620 MISSING_VALUES [value ] += chr (96 + i )
487621 int_value = struct .unpack ('<i' , struct .pack ('<f' , value ))[0 ] + increment
488- base = struct .pack ('<i' , int_value )
622+ float32_base = struct .pack ('<i' , int_value )
489623
490- base = b'\x00 \x00 \x00 \x00 \x00 \x00 \xe0 \x7f '
624+ float64_base = b'\x00 \x00 \x00 \x00 \x00 \x00 \xe0 \x7f '
491625 increment = struct .unpack ('q' , b'\x00 \x00 \x00 \x00 \x00 \x01 \x00 \x00 ' )[0 ]
492626 for i in range (27 ):
493- value = struct .unpack ('<d' , base )[0 ]
627+ value = struct .unpack ('<d' , float64_base )[0 ]
494628 MISSING_VALUES [value ] = '.'
495629 if i > 0 :
496630 MISSING_VALUES [value ] += chr (96 + i )
497631 int_value = struct .unpack ('q' , struct .pack ('<d' , value ))[0 ] + increment
498- base = struct .pack ('q' , int_value )
632+ float64_base = struct .pack ('q' , int_value )
633+
634+ BASE_MISSING_VALUES = {'int8' : 101 ,
635+ 'int16' : 32741 ,
636+ 'int32' : 2147483621 ,
637+ 'float32' : struct .unpack ('<f' , float32_base )[0 ],
638+ 'float64' : struct .unpack ('<d' , float64_base )[0 ]}
499639
500640 def __init__ (self , value ):
501641 self ._value = value
@@ -518,6 +658,22 @@ def __eq__(self, other):
518658 return (isinstance (other , self .__class__ )
519659 and self .string == other .string and self .value == other .value )
520660
661+ @classmethod
662+ def get_base_missing_value (cls , dtype ):
663+ if dtype == np .int8 :
664+ value = cls .BASE_MISSING_VALUES ['int8' ]
665+ elif dtype == np .int16 :
666+ value = cls .BASE_MISSING_VALUES ['int16' ]
667+ elif dtype == np .int32 :
668+ value = cls .BASE_MISSING_VALUES ['int32' ]
669+ elif dtype == np .float32 :
670+ value = cls .BASE_MISSING_VALUES ['float32' ]
671+ elif dtype == np .float64 :
672+ value = cls .BASE_MISSING_VALUES ['float64' ]
673+ else :
674+ raise ValueError ('Unsupported dtype' )
675+ return value
676+
521677
522678class StataParser (object ):
523679 _default_encoding = 'cp1252'
@@ -1111,10 +1267,10 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
11111267 umissing , umissing_loc = np .unique (series [missing ],
11121268 return_inverse = True )
11131269 replacement = Series (series , dtype = np .object )
1114- for i , um in enumerate (umissing ):
1270+ for j , um in enumerate (umissing ):
11151271 missing_value = StataMissingValue (um )
11161272
1117- loc = missing_loc [umissing_loc == i ]
1273+ loc = missing_loc [umissing_loc == j ]
11181274 replacement .iloc [loc ] = missing_value
11191275 else : # All replacements are identical
11201276 dtype = series .dtype
@@ -1390,6 +1546,45 @@ def _write(self, to_write):
13901546 else :
13911547 self ._file .write (to_write )
13921548
1549+ def _prepare_categoricals (self , data ):
1550+ """Check for categorigal columns, retain categorical information for
1551+ Stata file and convert categorical data to int"""
1552+
1553+ is_cat = [com .is_categorical_dtype (data [col ]) for col in data ]
1554+ self ._is_col_cat = is_cat
1555+ self ._value_labels = []
1556+ if not any (is_cat ):
1557+ return data
1558+
1559+ get_base_missing_value = StataMissingValue .get_base_missing_value
1560+ index = data .index
1561+ data_formatted = []
1562+ for col , col_is_cat in zip (data , is_cat ):
1563+ if col_is_cat :
1564+ self ._value_labels .append (StataValueLabel (data [col ]))
1565+ dtype = data [col ].cat .codes .dtype
1566+ if dtype == np .int64 :
1567+ raise ValueError ('It is not possible to export int64-based '
1568+ 'categorical data to Stata.' )
1569+ values = data [col ].cat .codes .values .copy ()
1570+
1571+ # Upcast if needed so that correct missing values can be set
1572+ if values .max () >= get_base_missing_value (dtype ):
1573+ if dtype == np .int8 :
1574+ dtype = np .int16
1575+ elif dtype == np .int16 :
1576+ dtype = np .int32
1577+ else :
1578+ dtype = np .float64
1579+ values = np .array (values , dtype = dtype )
1580+
1581+ # Replace missing values with Stata missing value for type
1582+ values [values == - 1 ] = get_base_missing_value (dtype )
1583+ data_formatted .append ((col , values , index ))
1584+
1585+ else :
1586+ data_formatted .append ((col , data [col ]))
1587+ return DataFrame .from_items (data_formatted )
13931588
13941589 def _replace_nans (self , data ):
13951590 # return data
@@ -1480,27 +1675,26 @@ def _check_column_names(self, data):
14801675 def _prepare_pandas (self , data ):
14811676 #NOTE: we might need a different API / class for pandas objects so
14821677 # we can set different semantics - handle this with a PR to pandas.io
1483- class DataFrameRowIter (object ):
1484- def __init__ (self , data ):
1485- self .data = data
1486-
1487- def __iter__ (self ):
1488- for row in data .itertuples ():
1489- # First element is index, so remove
1490- yield row [1 :]
14911678
14921679 if self ._write_index :
14931680 data = data .reset_index ()
1494- # Check columns for compatibility with stata
1495- data = _cast_to_stata_types (data )
1681+
14961682 # Ensure column names are strings
14971683 data = self ._check_column_names (data )
1684+
1685+ # Check columns for compatibility with stata, upcast if necessary
1686+ data = _cast_to_stata_types (data )
1687+
14981688 # Replace NaNs with Stata missing values
14991689 data = self ._replace_nans (data )
1500- self .datarows = DataFrameRowIter (data )
1690+
1691+ # Convert categoricals to int data, and strip labels
1692+ data = self ._prepare_categoricals (data )
1693+
15011694 self .nobs , self .nvar = data .shape
15021695 self .data = data
15031696 self .varlist = data .columns .tolist ()
1697+
15041698 dtypes = data .dtypes
15051699 if self ._convert_dates is not None :
15061700 self ._convert_dates = _maybe_convert_to_int_keys (
@@ -1515,6 +1709,7 @@ def __iter__(self):
15151709 self .fmtlist = []
15161710 for col , dtype in dtypes .iteritems ():
15171711 self .fmtlist .append (_dtype_to_default_stata_fmt (dtype , data [col ]))
1712+
15181713 # set the given format for the datetime cols
15191714 if self ._convert_dates is not None :
15201715 for key in self ._convert_dates :
@@ -1529,8 +1724,14 @@ def write_file(self):
15291724 self ._write (_pad_bytes ("" , 5 ))
15301725 self ._prepare_data ()
15311726 self ._write_data ()
1727+ self ._write_value_labels ()
15321728 self ._file .close ()
15331729
1730+ def _write_value_labels (self ):
1731+ for vl in self ._value_labels :
1732+ self ._file .write (vl .generate_value_label (self ._byteorder ,
1733+ self ._encoding ))
1734+
15341735 def _write_header (self , data_label = None , time_stamp = None ):
15351736 byteorder = self ._byteorder
15361737 # ds_format - just use 114
@@ -1585,9 +1786,15 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
15851786 self ._write (_pad_bytes (fmt , 49 ))
15861787
15871788 # lbllist, 33*nvar, char array
1588- #NOTE: this is where you could get fancy with pandas categorical type
15891789 for i in range (nvar ):
1590- self ._write (_pad_bytes ("" , 33 ))
1790+ # Use variable name when categorical
1791+ if self ._is_col_cat [i ]:
1792+ name = self .varlist [i ]
1793+ name = self ._null_terminate (name , True )
1794+ name = _pad_bytes (name [:32 ], 33 )
1795+ self ._write (name )
1796+ else : # Default is empty label
1797+ self ._write (_pad_bytes ("" , 33 ))
15911798
15921799 def _write_variable_labels (self , labels = None ):
15931800 nvar = self .nvar
@@ -1624,9 +1831,6 @@ def _prepare_data(self):
16241831 data_cols .append (data [col ].values )
16251832 dtype = np .dtype (dtype )
16261833
1627- # 3. Convert to record array
1628-
1629- # data.to_records(index=False, convert_datetime64=False)
16301834 if has_strings :
16311835 self .data = np .fromiter (zip (* data_cols ), dtype = dtype )
16321836 else :
0 commit comments