"""
Scientific.IO.NetCDF-like wrapper for python shelve.

I wrote this module because:
a) I wanted to use the familiar ScientificPython NetCDF interface 
   for I/O of Numeric arrays even where NetCDF was not installed, and
b) I was frustrated with the lack of support for built-in compression in
   the NetCDF library. 
   
This module uses bzip2 compression to create files that are
often more than 3 times smaller than true NetCDF files.

Shelve archives created with this module can easily be converted
to and from true NetCDF files using the shltonc and nctoshl methods
of NetCDFFile.

Differences with Scientific.IO.NetCDF:

1) data must be assigned to NetCDFVariable object along first
   dimension only. For example,
   x[0] = array[0,:,:] or x[0,:,:] = array[0,:,:] is OK 
   (where x is a NetCDFVariable object), but x[0,1,:] = array[0,1,:]
   will raise an exception (IndexError).
   However, data can be retrieved from a NetCDFVariable object by slicing
   along any dimension.
2) only the first dimension can be unlimited.
3) Data is archived in a directory, with each variable
   in a separate file.  Global variable and attribute information
   is also stored in a separate file.
4) bzip2 compression is used for multi-dimensional variables
   (files are much smaller than netCDF, usually by a factor of 2-3)

Example code is in the 'examples' directory of the source distribution.

Requires Python >= 2.3 with bsddb and bz2 standard library modules.

Jeffrey Whitaker <jeffrey.s.whitaker@noaa.gov>

Version: 20050822
"""
#Copyright 2004 by Jeffrey Whitaker.

#Permission to use, copy, modify, and distribute this software and its
#documentation for any purpose and without fee is hereby granted,
#provided that the above copyright notice appear in all copies and that
#both that copyright notice and this permission notice appear in
#supporting documentation.

#THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
#INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO
#EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR
#CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF
#USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
#OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
#PERFORMANCE OF THIS SOFTWARE.

import shelve, sys, sys, os, math

try:
    import Numeric
except:
    raise ImportError, 'requires Numeric (http://sf.net/projects/numpy)'
try:
    import bsddb
except:
    raise ImportError, 'requires bsddb standard library module'
try:
    import bz2
except:
    raise ImportError, 'requires bz2 standard library module'
try:
    from Scientific.IO import NetCDF
    hasnetcdf = True
except:
    hasnetcdf = False

# max sequence index: the stop value the interpreter supplies for an
# open-ended slice.  _extractslice treats a stop equal to this value as
# "up to the end of the dimension".
class _MaxSlice:
    """Helper whose only purpose is to compute the module-level _maxslice."""
    def __init__(self):
        # self[:] is routed to __getitem__ below, which hands back the
        # stop of the slice object it receives.
        self.value = self[:]
    def __getitem__(self,key):
        return key.stop
# NOTE(review): presumably sys.maxint on the Python 2 interpreters this
# module targets (old-style class, plain [:] slice) -- confirm.
_maxslice = _MaxSlice().value

def _mkdir(newdir):
    """Create directory newdir together with any missing parents.

    - newdir is already a directory: nothing happens.
    - a regular file occupies the path: OSError is raised.
    - parent directories are missing: they are created first
      (recursively), then newdir itself.
    """
    # nothing to do when the directory is already there.
    if os.path.isdir(newdir):
        return
    # refuse to shadow a regular file.
    if os.path.isfile(newdir):
        raise OSError("a file with the same name as the desired " \
                      "dir, '%s', already exists." % newdir)
    parent, leaf = os.path.split(newdir)
    # build the chain of parents from the top down.
    if parent and not os.path.isdir(parent):
        _mkdir(parent)
    # an empty leaf means the path ended in a separator and the
    # recursive call above has already created the directory.
    if leaf:
        os.mkdir(newdir)

def _extractslice(k,lendim):
    """Return (start,stop,step) for slice k with the defaults filled in.

    k -- a slice object.
    lendim -- length of the dimension being sliced.

    A missing start becomes 0 and a missing step becomes 1.  A missing
    stop, or a stop equal to the module-level _maxslice value, becomes
    lendim.  Bounds that are present are returned exactly as given:
    negative values are not translated and stop is not clamped to lendim.
    """
    start, stop, step = k.start, k.stop, k.step
    # compare against None by identity so that a bound with an unusual
    # __eq__ can never be mistaken for a missing one.
    if start is None:
        start = 0
    if stop is None or stop == _maxslice:
        stop = lendim
    if step is None:
        step = 1
    return start,stop,step

def _sliceparams(key,shape):
    """Translate an index key into a list of (start,stop,step) tuples.

    key -- a slice object, or a tuple of slice objects.
    shape -- shape of the array being sliced.

    The n-th entry of the result is _extractslice applied to the n-th
    slice and shape[n].  A lone slice gives a one-element list built
    from shape[0].
    """
    # a bare slice only addresses the first dimension.
    if not isinstance(key,tuple):
        return [_extractslice(key,shape[0])]
    # pair every slice in the tuple with the length of its dimension.
    return [_extractslice(k,shape[n]) for n,k in enumerate(key)]

def _quantize(data,least_significant_digit):
    """Quantize data to improve compression.

    data -- array providing typecode() and astype(); it is passed to
            Numeric.around.
    least_significant_digit -- power of ten of the smallest decimal
            place that should be retained.

    The result is Numeric.around(scale*data)/scale cast back to the
    typecode of data, where scale is 2**bits and bits is derived from
    least_significant_digit.  For example, least_significant_digit=1
    gives bits=4.
    """
    precision = 10.**-least_significant_digit
    exponent = math.log(precision,10)
    # round the exponent away from zero to a whole number.
    if exponent >= 0:
        exponent = int(math.ceil(exponent))
    else:
        exponent = int(math.floor(exponent))
    # number of binary digits that cover that decimal precision.
    nbits = math.ceil(math.log(10.**-exponent,2))
    factor = 2.**nbits
    # remember the incoming type so the result can be cast back to it.
    original_type = data.typecode()
    rounded = Numeric.around(factor*data)/factor
    return rounded.astype(original_type)

class NetCDFFile:
    """
    netCDF file Constructor: NetCDFFile(dirname, mode="r",history=None)

    Arguments:

    dirname -- Name of directory to hold data archive.
               Separate shelve files are created in this directory for
               each variable (plus one containing global attributes
               and variable info).

    mode -- access mode. "r" means read-only; no data can be modified.
            "w" means write; the directory is created if necessary and
            a new, empty contents shelf replaces any existing one.
            "a" means append (in analogy with serial files); an existing
            archive is opened for reading and writing.
            Any other value raises ValueError.

    history -- a string that is used to define the global NetCDF
    attribute 'history'.

    A NetCDFFile object has two standard attributes: 'dimensions' and
    'variables'. The values of both are dictionaries, mapping dimension
    names to their associated lengths and variable names to variables,
    respectively. Application programs should never modify these
    dictionaries.

    All other attributes correspond to global attributes defined in the
    netCDF file. Global file attributes are created by assigning to
    an attribute of the NetCDFFile object.  Reading an attribute that
    does not exist raises AttributeError.
    """

    # keys of the contents shelf that hold bookkeeping information
    # rather than global attributes.
    _reserved = ('varinfo','dimensions','unlimdim_val')

    # Numeric typecode -> type name used by __repr__.
    _typenames = {'s':'short','b':'byte','l':'int','i':'int',
                  'f':'float','d':'double'}

    def __init__(self,dirname,mode='r',history=None):
        """Open (mode 'r' or 'a') or create (mode 'w') the archive in
        directory dirname.

        Raises ValueError for an unknown mode and TypeError when an
        archive opened with mode 'r' or 'a' has an empty contents shelf.
        """
        if mode not in ('r','w','a'):
            raise ValueError('mode must be "r", "w" or "a", not %r' % (mode,))
        # only a new archive needs its directory made; opening a missing
        # archive for reading should not leave an empty directory behind.
        if mode == 'w':
            _mkdir(dirname)
        self.__dirname = dirname
        # must be set before any public attribute is assigned, because
        # __setattr__ consults it.
        self.__mode = mode
        # shelve flags: read-only, open an existing shelf read/write, or
        # create a new shelf.
        flag = {'r':'r','a':'w','w':'n'}[mode]
        # writeback=True is needed to write to the shelf.
        writeback = (mode != 'r')
        # open contents shelve (holds global attributes and variable info).
        filename = os.path.join(dirname,'contents.shl')
        self.__shelf = shelve.open(filename,flag=flag,protocol=-1,writeback=writeback)
        if mode == 'w':
            # initialize dimension and variable dictionaries for a new shelf.
            self.dimensions = {}
            self.variables = {}
            self.varinfo = {}
            self.__shelf['unlimdim_val'] = 0
        else:
            # retrieve information from an existing shelf.
            if len(self.__shelf.keys()) == 0:
                # do not leave the shelf open behind the error.
                self.__shelf.close()
                raise TypeError('%s is empty or does not exist - use mode="w" instead' % filename)
            # open all variables.
            self.variables = {}
            for varname in self.varinfo.keys():
                info = self.varinfo[varname]
                self.variables[varname] = NetCDFVariable(varname,self,info['typecode'],info['dimensions'])
        # set history attribute.
        if mode != 'r' and history is not None:
            self.history = history

    def createDimension(self,dimname,size):
        """Creates a new dimension with the given "dimname" and
        "size". "size" must be a positive integer or 'None',
        which stands for the unlimited dimension. There can
        be only one unlimited dimension per dataset; trying to add a
        second one raises ValueError and leaves the existing
        dimensions untouched."""
        if size is None:
            # make sure there is only one unlimited dimension; check
            # before storing so a rejected dimension is not recorded.
            for name,existing in self.dimensions.items():
                if existing is None and name != dimname:
                    raise ValueError('only one unlimited dimension allowed!')
        self.dimensions[dimname] = size

    def createVariable(self,varname,datatype,dimensions,least_significant_digit=None):
        """Creates a new variable with the given "varname", "datatype", and
        "dimensions". The "datatype" is a one-letter string with the same
        meaning as the typecodes for arrays in module Numeric; in
        practice the predefined type constants from Numeric should
        be used. "dimensions" must be a tuple containing dimension
        names (strings) that have been defined previously.
        An unlimited dimension must be the first (leftmost)
        dimension of the variable, otherwise ValueError is raised.

        If the optional keyword parameter 'least_significant_digit' is
        specified, multidimensional variables will be truncated (quantized).
        This can significantly improve compression.  For example, if 
        least_significant_digit=1, data will be quantized using
        Numeric.around(scale*data)/scale, where scale = 2**bits, and
        bits is determined so that a precision of 0.1 is retained (in 
        this case bits=4). 
        From http://www.cdc.noaa.gov/cdc/conventions/cdc_netcdf_standard.shtml:
        "least_significant_digit -- power of ten of the smallest decimal
        place in unpacked data that is a reliable value."

        The return value is the NetCDFVariable object describing the
        new variable."""
        # make sure unlimited dimension, if it exists, is the first dimension.
        for ndim,dim in enumerate(dimensions):
            if self.__shelf['dimensions'][dim] is None and ndim != 0:
                raise ValueError('unlimited dimension must be first dimension of variable!')
        # create NetCDFVariable instance.
        var = NetCDFVariable(varname,self,datatype,dimensions,least_significant_digit=least_significant_digit)
        # update shelf variable dictionary, global variable 
        # info dict.
        self.variables[varname] = var
        self.varinfo[varname] = {'dimensions':dimensions,'typecode':datatype}
        return var

    def close(self):
        """Closes the file. Any read or write access to the file
        or one of its variables after closing raises an exception."""
        for var in self.variables.values():
            var._close()
        self.__shelf.close()

    def sync(self):
        "Writes all buffered data to the disk files."
        # every variable lives in its own shelf, so each one has to be
        # flushed as well as the contents shelf.
        for var in self.variables.values():
            var._NetCDFVariable__varshelf.sync()
        self.__shelf.sync()

    def __repr__(self):
        """produces output similar to 'ncdump -h'"""
        info=[self.__dirname+' {\n']
        info.append('dimensions:\n')
        for key,val in self.dimensions.items():
            if val is None:
                size = self.__shelf['unlimdim_val']
                info.append('    '+key+' = UNLIMITED ; // ('+repr(size)+' currently)\n')
            else:
                info.append('    '+key+' = '+repr(val)+' ;\n')
        info.append('variables:\n')
        for varname in self.varinfo.keys():
            dim = self.varinfo[varname]['dimensions']
            var = self.variables[varname]
            datatype = var.typecode()
            # typecodes without a netCDF-style name are shown as they are.
            typename = self._typenames.get(datatype,str(datatype))
            info.append('    '+typename+' '+varname+str(dim)+' ;\n')
            for key,val in var._NetCDFVariable__varshelf['attributes'].items():
                if key not in ['dimensions','shape']:
                    info.append('        '+varname+':'+key+' = '+repr(val)+' ;\n')
        info.append('// global attributes:\n')
        for key,val in self.__shelf.items():
            if key not in self._reserved:
                info.append('        :'+key+' = '+repr(val)+' ;\n')
        info.append('}')
        return ''.join(info)

    def __setattr__(self,name,value):
        """Store the attribute on the instance and, unless it is private,
        is 'variables', or the file is read-only, in the contents shelf."""
        self.__dict__[name] = value
        # private (name-mangled) attributes and the dictionary of open
        # variable objects are never written to the shelf.  The prefix is
        # spelled out so that subclasses keep working.
        if name.startswith('_NetCDFFile') or name == 'variables':
            return
        if self.__mode != 'r':
            self.__shelf[name] = value

    def __getattr__(self,name):
        """Look up attributes not found on the instance in the contents
        shelf.  Raises AttributeError when there is no such entry."""
        # special and private names never live in the shelf; refusing
        # them here also avoids endless recursion when the shelf itself
        # has not been set yet.
        if name.startswith('_NetCDFFile') or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        try:
            return self.__shelf[name]
        except KeyError:
            # report a missing attribute the way getattr/hasattr expect.
            raise AttributeError(name)

    def shltonc(self,filename,packshort=False,scale_factor=None,add_offset=None):
        """convert NetCDFShelf.NetCDFFile to a true netcdf file (filename).
        Requires Scientific.IO.NetCDF module. If packshort=True, 
        variables are packed as short integers using the dictionaries
        scale_factor and add_offset. The dictionary keys are the 
        the variable names in the shelve archive to be packed as short
        integers; variables missing from either dictionary (or all
        variables, when a dictionary is not given) are written unpacked."""
        if not hasnetcdf:
            print('Scientific.IO.NetCDF must be installed to convert shelve to NetCDF')
            return
        # a missing dictionary simply means that nothing is packed.
        if scale_factor is None:
            scale_factor = {}
        if add_offset is None:
            add_offset = {}
        ncfile = NetCDF.NetCDFFile(filename,'w')
        try:
            # create dimensions.
            for dimname,size in self.dimensions.items():
                ncfile.createDimension(dimname,size)
            # create variables.
            for varname in self.varinfo.keys():
                info = self.varinfo[varname]
                shlvar = self.variables[varname]
                packvar = packshort and varname in scale_factor and varname in add_offset
                if packvar:
                    print('packing %s as short integers ...' % (varname,))
                    datatype = 's'
                else:
                    datatype = info['typecode']
                var = ncfile.createVariable(varname,datatype,info['dimensions'])
                for key,val in shlvar._NetCDFVariable__varshelf['attributes'].items():
                    if key not in ['shape','dimensions']:
                        setattr(var,key,val)
                # the packing parameters are set once, after the copied
                # attributes, so that they take precedence.
                if packvar:
                    setattr(var,'scale_factor',scale_factor[varname])
                    setattr(var,'add_offset',add_offset[varname])
                for n in range(shlvar.shape[0]):
                    if packvar:
                        var[n] = ((1./scale_factor[varname])*(shlvar[n] - add_offset[varname])).astype('s')
                    else:
                        var[n] = shlvar[n]
            # create global attributes.
            for key,val in self.__shelf.items():
                if key not in self._reserved:
                    setattr(ncfile,key,val)
        finally:
            # close file, also when the conversion fails part way.
            ncfile.close()

    def nctoshl(self,filename,unpackshort=True):
        """convert a true netcdf file (filename) to a NetCDFShelf.NetCDFFile.
        Requires Scientific.IO.NetCDF module. If unpackshort=True, variables
        stored as short integers with a scale and offset are unpacked to floating
        point variables in the netCDF-shelve archive.  If the least_significant_digit
        attribute is set, the data is quantized to improve compression."""
        if not hasnetcdf:
            print('Scientific.IO.NetCDF must be installed to convert NetCDF to shelve')
            return
        # attributes that arrive as length-1 Numeric arrays are stored
        # as plain scalars.
        arraytype = type(Numeric.array([1]))
        ncfile = NetCDF.NetCDFFile(filename,'r')
        try:
            # create dimensions.
            for dimname,size in ncfile.dimensions.items():
                self.createDimension(dimname,size)
            # create variables.
            for varname,ncvar in ncfile.variables.items():
                if hasattr(ncvar,'least_significant_digit'):
                    lsd = ncvar.least_significant_digit
                else:
                    lsd = None
                if unpackshort and hasattr(ncvar,'scale_factor') and hasattr(ncvar,'add_offset'):
                    dounpackshort = True
                    datatype = 'f'
                else:
                    dounpackshort = False
                    datatype = ncvar.typecode()
                var = self.createVariable(varname,datatype,ncvar.dimensions,least_significant_digit=lsd)
                for key,val in ncvar.__dict__.items():
                    # the packing parameters no longer apply once the
                    # data has been unpacked.
                    if dounpackshort and key in ['add_offset','scale_factor']: continue
                    if dounpackshort and key == 'missing_value': val=1.e30
                    if isinstance(val,arraytype) and len(val)==1:
                        val = val[0]
                    setattr(var,key,val)
                for n in range(ncvar.shape[0]):
                    if dounpackshort:
                        idata = ncvar[n]
                        fdata = (ncvar.scale_factor*idata+ncvar.add_offset).astype('f')
                        if hasattr(ncvar,'missing_value'):
                            fdata = Numeric.where(idata >= ncvar.missing_value, 1.e30, fdata)
                        var[n] = fdata
                    else:
                        var[n] = ncvar[n]
            # create global attributes.
            for key,val in ncfile.__dict__.items():
                if isinstance(val,arraytype) and len(val)==1:
                    val = val[0]
                setattr(self,key,val)
        finally:
            # close file, also when the conversion fails part way.
            ncfile.close()

class NetCDFVariable:
    """Variable in a netCDF file

    NetCDFVariable objects are constructed by calling the method
    'createVariable' on the NetCDFFile object.

    NetCDFVariable objects behave much like array objects defined
    in module Numeric, except that their data resides in a file.
    Data is read by indexing and written by assigning to an
    indexed subset; the entire array can be accessed by the index
    '[:]'. Unlike Numeric arrays (and Scientific.IO.NetCDFVariable
    objects), data can be assigned to NetCDFVariable objects only by 
    indexing along the first (leftmost) dimension. Data may be
    retrieved from a NetCDFVariable object by indexing along any
    dimension.  When reading, a slice along the first dimension is
    interpreted with the usual Python rules (negative bounds, negative
    steps, bounds clamped to the current length); the remaining indices
    are applied to the unpacked Numeric array.

    Multidimensional variables are compressed on disk using bz2 compression,
    after optionally truncating to a precision specified by the
    least_significant_digit keyword argument to createVariable.
    Truncation will significantly improve compression.

    NetCDFVariable objects also have attribute
    "shape" with the same meaning as for arrays, but the shape
    cannot be modified. There is another read-only attribute
    "dimensions", whose value is the tuple of dimension names.

    All other attributes correspond to variable attributes defined in the
    netCDF file. Variable attributes are created by assigning to
    an attribute of the NetCDFVariable object.  Reading an attribute
    that does not exist raises AttributeError.
    """

    def __init__(self, varname, shelf, datatype, dimensions, least_significant_digit=None):
        """Open the shelf file of variable varname in the archive
        described by the NetCDFFile instance shelf, creating and
        initializing it when the archive is being written (mode 'w') or
        when the variable does not exist yet in append mode."""
        self.__datatype = datatype
        self.__varname = varname
        self.__shelf = shelf._NetCDFFile__shelf
        self.__mode = shelf._NetCDFFile__mode
        # shelve flags: 'r' read-only, 'n' always start a new shelf,
        # 'c' open for read/write and create the shelf when it is missing
        # (so that variables can be added to an archive being appended to).
        if self.__mode == 'w':
            flag = 'n'
        elif self.__mode == 'a':
            flag = 'c'
        else:
            flag = self.__mode
        # writeback=True is needed to write to the shelf.
        writeback = (self.__mode != 'r')
        filename = os.path.join(shelf._NetCDFFile__dirname,varname+'.shl')
        self.__varshelf = shelve.open(filename,flag=flag,protocol=-1,writeback=writeback)
        isnew = (self.__mode == 'w')
        if self.__mode == 'a' and 'attributes' not in self.__varshelf:
            isnew = True
        if isnew:
            shape = []
            for dim in dimensions:
                size = self.__shelf['dimensions'][dim]
                # the unlimited dimension starts out empty.
                if size is None:
                    shape.append(0)
                else:
                    shape.append(size)
            attributes = {'dimensions':dimensions,'shape':tuple(shape)}
            if least_significant_digit is not None:
                attributes['least_significant_digit'] = least_significant_digit
            self.__varshelf['attributes'] = attributes

    def _close(self):
        """close the NetCDFVariable instance"""
        self.__varshelf.close()

    def typecode(self):
        """Return the variable's type code (a string)."""
        return self.__datatype

    def _store(self,index,record,unlimited,compress):
        """Write one record at position index along the first dimension.

        unlimited -- true when the first dimension is the unlimited one.
        compress -- true for multi-dimensional variables, whose records
                    are shape-checked, flattened and bz2-compressed.

        Raises ValueError when record has the wrong shape and KeyError
        when index is negative, past the end of a fixed dimension, or
        would leave a gap in the unlimited dimension.  The stored shape
        (and the unlimited-dimension length in the contents shelf) is
        updated as soon as the record has been written.
        """
        shape = tuple(self.shape)
        if compress:
            if tuple(record.shape) != shape[1:]:
                raise ValueError('data does not conform to shape of netCDF variable object')
            payload = bz2.compress(Numeric.ravel(record).tostring())
        else:
            payload = record
        if index < 0:
            raise KeyError('key out of range')
        nrecs = shape[0]
        key = str(index)
        if key not in self.__varshelf:
            if unlimited:
                # records can only be appended directly after the last one.
                if index > nrecs:
                    raise KeyError('last index is %s, your are trying to append at %s'%(nrecs-1,index))
                nrecs = nrecs+1
            elif index > nrecs-1:
                raise KeyError('key out of range')
        self.__varshelf[key] = payload
        if nrecs != shape[0]:
            self.__varshelf['attributes']['shape'] = (nrecs,)+shape[1:]
            # other variables may already be longer; never shrink the
            # recorded length of the unlimited dimension.
            self.__shelf['unlimdim_val'] = max(self.__shelf['unlimdim_val'],nrecs)

    def _readrecord(self,index):
        """Return the record stored at position index along the first
        dimension: the stored value for rank-1 variables, otherwise the
        decompressed Numeric array.  Raises KeyError when no such record
        has been written."""
        raw = self.__varshelf[str(index)]
        shape = self.shape
        if len(shape) == 1:
            return raw
        return Numeric.reshape(Numeric.fromstring(bz2.decompress(raw),self.typecode()),shape[1:])

    def __setitem__(self,key,data):
        """Assign data along the first dimension.

        key may be an integer, a slice, or a tuple whose first element
        is one of these and whose remaining elements are all ':'.
        Raises IndexError for any other tuple, ValueError when a record
        does not match the shape of the variable, and KeyError for
        positions that are out of range.  The stop of a slice is not
        clamped to the current length, so a slice can extend the
        unlimited dimension.
        """
        # allow only a single key.  If key is a tuple, key[1:]
        # must consist of slices with start=stop=step=None.
        if isinstance(key,tuple):
            for k in key[1:]:
                if not isinstance(k,slice) or k.start is not None or k.stop is not None or k.step is not None:
                    raise IndexError('must write all data along 1st dimension at once')
            if len(key) > 1 and len(key) != len(self.shape):
                raise IndexError('you are using indices that are the wrong shape for the netCDF variable object')
            key = key[0]
        shape = list(self.shape)
        # force typecast of data to specified typecode for variable
        try:
            data = data.astype(self.typecode())
        except AttributeError: # if not a Numeric array, just continue.
            pass
        compress = len(shape) > 1
        # quantize data if least_significant digit attribute set 
        # and variable is multi-dimensional.
        if compress and hasattr(self,'least_significant_digit'):
            data = _quantize(data,self.least_significant_digit)
        unlimited = self.__shelf['dimensions'][self.dimensions[0]] is None
        if isinstance(key,slice):
            # key is a slice object: write one record per index.
            start,stop,step = _sliceparams(key,shape)[0]
            for n,k in enumerate(range(start,stop,step)):
                self._store(k,data[n],unlimited,compress)
        else:
            # single key assignment
            if key < 0: key = key + shape[0]
            self._store(key,data,unlimited,compress)

    def __getitem__(self,key):
        """Return data selected by key.

        key may be an integer, a slice, or a tuple of integers and
        slices.  The first index selects records along the first
        dimension, the remaining indices are applied to the unpacked
        data.  The special key 'attributes' returns the dictionary of
        variable attributes.  Raises KeyError when a selected record
        has not been written.
        """
        if key == 'attributes':
            return self.__varshelf['attributes']
        shape = self.shape
        if isinstance(key,tuple):
            first, rest = key[0], tuple(key[1:])
        else:
            first, rest = key, ()
        if isinstance(first,slice):
            # unpack all data for requested items along first dimension.
            start,stop,step = first.indices(shape[0])
            indices = range(start,stop,step)
            if len(shape) == 1: # rank-1 array
                data = Numeric.zeros(len(indices),self.typecode())
            else: # multi-dimensional array
                data = Numeric.zeros([len(indices)]+list(shape[1:]),self.typecode())
            for n,k in enumerate(indices):
                data[n] = self._readrecord(k)
            if rest:
                # let Numeric apply the indices of the other dimensions.
                return data[(slice(None),)+rest]
            return data
        # single record retrieval
        if first < 0: first = first + shape[0]
        record = self._readrecord(first)
        if rest:
            return record[rest]
        return record

    def __len__(self):
        """Return the current length of the first dimension."""
        return self.shape[0]

    def __setattr__(self,name,value):
        """Store the attribute on the instance and, unless it is private
        or the file is read-only, in the attributes of the variable.
        Raises KeyError for the read-only names 'dimensions' and 'shape'."""
        if name in ['dimensions','shape']:
            raise KeyError('"dimensions" and "shape" are read-only attributes - cannot modify')
        self.__dict__[name] = value
        # private (name-mangled) attributes are never written to the
        # shelf.  The prefix is spelled out so that subclasses keep working.
        if not name.startswith('_NetCDFVariable') and self.__mode != 'r':
            self.__varshelf['attributes'][name] = value

    def __getattr__(self,name):
        """Look up attributes not found on the instance in the stored
        variable attributes.  Raises AttributeError when there is no
        such attribute."""
        # special and private names never live in the shelf; refusing
        # them here also avoids endless recursion when the shelf itself
        # has not been set yet.
        if name.startswith('_NetCDFVariable') or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        try:
            return self.__varshelf['attributes'][name]
        except KeyError:
            # report a missing attribute the way getattr/hasattr expect.
            raise AttributeError(name)

    def append(self,data):
        """append data along unlimited dimension."""
        self[self.shape[0]] = data