require "arraydimindex"

# Marker class for "rubber" (elastic) dimensions in index specifications.
# It carries no behavior of its own: its instances only serve as unique
# placeholder objects, published through the top-level constants below so
# that they are available globally.
class Rubber
end

# The rubber dimension.
RUB = Rubber.new

# Max rank
ndmax = 7

# One distinct marker per slot 0..ndmax.
SKIP = (0..ndmax).collect { Rubber.new }

####################################################################

# Maps a subset specification of a multi-dimensional array onto positions
# in the array's flat (1D) storage, in which the first dimension varies
# fastest (the stride of dimension 0 is 1).
#
# State built by the constructor:
#   @dimmap    - one ArrayDimIndex per dimension (a single one for a mask)
#   @stride    - stride of each dimension in the original flat storage
#   @stridenew - stride of each dimension in the cropped (subset) array
#
# NOTE(review): ArrayDimIndex comes from "arraydimindex" and is not visible
# here. This class relies on it providing `length` and `indices`, and on the
# value returned by `indices` supporting `dup`, `[range]=`, splatting and
# element-wise `+ Integer` -- confirm against that file.
class ArrayIndex

  # shape:: sizes of the dimensions of the original array
  # idx::   either a single logical NumArray (treated as a mask over all
  #         elements), or one index specifier per dimension, where RUB and
  #         SKIP[n] may stand in for several dimensions (see rubber_intrprt).
  #
  # Raises RuntimeError if, after expanding rubber dimensions, the number of
  # specifiers differs from the rank.
  def initialize(shape,*idx)
    rank = shape.length   # rank of the original array
    ##if idx.length == 1 && idx[0].is_a?(NumArray) && idx[0].numtype==Byte then
    if idx.length == 1 && idx[0].is_a?(NumArray) && idx[0].logical? then
      # assume a mask: it addresses the array as one flat dimension
      ntot = 1
      shape.length.times { |i| ntot *= shape[i] }
      @dimmap = [ArrayDimIndex.new(idx[0], ntot)]
      @stride = [1]
      @stridenew = [1]
    # (disabled branch, kept for reference)
    #    elsif idx.length == 1 && rank != 1 then
    #      # assume idx is an 1D array of indices to elements
    #      ntot=1; for i in 0..shape.length-1; ntot *= shape[i]; end
    #      @dimmap = [ArrayDimIndex.new(idx[0],ntot)]
    #      @stride = [1]
    #      @stridenew = [1]
    else
      # assume idx consists of one index specifier for each dimension
      idx = rubber_intrprt(idx, rank)       # interpret rubber dimensions
      raise "\# of index specifiers != rank" if idx.length != rank
      @dimmap = []
      rank.times { |d| @dimmap[d] = ArrayDimIndex.new(idx[d], shape[d]) }
      @stride = [1]
      @stridenew = [1]
      1.upto(rank - 1) do |d|
        @stride[d] = @stride[d-1] * shape[d-1]
        @stridenew[d] = @stridenew[d-1] * @dimmap[d-1].length
      end
    end
  end

  # Cut the flat data array according to the subset specification in self.
  #
  # Returns two values: the selected elements (in subset order) and the
  # shape of the new array.
  def crop(data)
    indx = flat_indices
    # Array#indices no longer exists in Ruby >= 1.9; values_at is its
    # replacement. Fall back to indices for containers that only have that.
    if data.respond_to?(:values_at) then
      newdata = data.values_at(*indx)
    else
      newdata = data.indices(*indx)
    end
    newshape = @dimmap.collect { |dm| dm.length }
    return newdata, newshape
  end

  # Store val into the subset of data (data is modified directly).
  #
  # val may be
  # * a Numeric  - written to every selected element,
  # * an Array   - must have exactly as many elements as the subset,
  # * a NumArray - must have at least as many elements as the subset.
  # Any other kind of val leaves data untouched.
  #
  # Raises RuntimeError on a length mismatch. Returns true.
  def setval(data,val)
    indx = flat_indices
    count = indx.length
    if val.is_a?(Numeric) then
      count.times { |i| data[indx[i]] = val }
    elsif val.is_a?(Array) then
      if val.length != count then
        raise "lengths of the subset and new data do not agree"
      end
      count.times { |i| data[indx[i]] = val[i] }
    elsif val.is_a?(NumArray) then
      if val.length < count then
        raise "val.length is smaller than expected"
      end
      v = val.to_basic
      count.times { |i| data[indx[i]] = v[i] }
    end
    true
  end

  ###################################################################
  private

  # Positions, in the original flat storage, of every element of the
  # subset, listed in the order of the cropped array. Shared by crop and
  # setval so that both always address the same elements.
  def flat_indices
    # dup: the list is expanded in place below, and that must not leak into
    # @dimmap[0] should `indices` hand out its internal list.
    indx = @dimmap[0].indices.dup
    1.upto(@dimmap.length - 1) do |d|
      # Snapshot of the positions built from dimensions 0..d-1. The first
      # block of indx is overwritten at j == 0, so it cannot be used as the
      # base for the later j.
      base = indx.dup
      width = @stridenew[d]
      dim_indices = @dimmap[d].indices
      @dimmap[d].length.times do |j|
        indx[(width*j)..(width*(j+1)-1)] = base + @stride[d]*dim_indices[j]
      end
    end
    indx
  end

  # Substitute the rubber markers in idx with the appropriate number of
  # 0..-1 (whole-dimension) specifiers:
  # * SKIP[n] becomes exactly n whole dimensions (SKIP[0] just disappears),
  # * RUB becomes as many whole dimensions as are needed to reach rank
  #   (none if the other specifiers already cover rank or more).
  #
  # idx is modified in place and returned. Raises RuntimeError if RUB
  # occurs more than once, since the expansion would be ambiguous.
  def rubber_intrprt(idx,rank)
    return idx unless idx.any? { |spec| spec.is_a?(Rubber) }

    whole = 0..-1

    # (skip a specified number of dimensions)
    SKIP.each_with_index do |marker, n|
      while (pos = idx.index(marker))
        idx[pos, 1] = [whole] * n
      end
    end

    # real rubber dimension
    pos = idx.index(RUB)
    if pos then
      if pos != idx.rindex(RUB) then    # Two or more -> ambiguous
        raise(RuntimeError,"Two or more rubber dimension exists")
      end
      # dimensions not claimed by the other specifiers
      nfill = rank - (idx.length - 1)
      nfill = 0 if nfill < 0
      idx[pos, 1] = [whole] * nfill
    end
    idx
  end

end


if __FILE__ == $0
  # Self-test, run only when this file is executed directly.
  # Needs the project's "numarray" library.
  require "numarray"

  p '*** test 1 ****'
  ary = NumArray.new(5,5,3)
  ary.span
  p ary
  # Hash specifiers are written with "=>": the old "{key,value}" literal
  # form is a syntax error in Ruby >= 1.9.
  # NOTE(review): {true=>2} is presumably "whole dimension, every 2nd
  # element" -- confirm against ArrayDimIndex.
  ai = ArrayIndex.new(ary.shape, 1..2, 0..2, {true=>2})
  p ai
  data = ary.to_basic
  newdata,newshape = ai.crop(data)
  asub=NumArray.new(*newshape)
  asub.setvals(newdata)
  p asub.shape,asub


  p '*** test 2 **** (rubber dimension)'
  ary = NumArray.new(4,3,3,2)
  ary.span
  # index 0 of the 2nd dimension from the last; everything else is kept whole
  ai = ArrayIndex.new(ary.shape, RUB,0,SKIP[1])
  data = ary.to_basic
  newdata,newshape = ai.crop(data)
  asub=NumArray.new(*newshape)
  asub.setvals(newdata)
  p asub.shape,asub

  p '*** test 3 **** (1D indexing)'
  ary = NumArray.new(5,5,3)
  ary.span
  # NOTE(review): the 1D-indexing branch of ArrayIndex#initialize is
  # commented out, so a single specifier for this rank-3 array raises
  # "# of index specifiers != rank". Re-enable that branch or drop this test.
  ai = ArrayIndex.new(ary.shape, {(0..9)=>2})
  p ai
  data = ary.to_basic
  newdata,newshape = ai.crop(data)
  asub=NumArray.new(*newshape)
  asub.setvals(newdata)
  p asub.shape,asub


end
