require 'basicnumarray'
require 'forwarding'
require 'arrayindex'

# NumArray: an N-dimensional numeric array.  The elements live in a flat
# BasicNumArray (@data); @shape holds the length of each dimension and @ntot
# the total number of elements.  Element access ([] / []=) is delegated to
# ArrayIndex, and most element-wise math is forwarded to @data via Forwarding.
class NumArray
  extend Forwarding

  # Create an array with the given dimension lengths.
  #
  # USAGE: NumArray.new(D1, D2, ..)  or  NumArray.new([D1, D2, ..])
  def initialize(*shape)
    shape = shape[0] if shape.length == 1 && shape[0].is_a?(Array)
    @shape = shape.dup                                   # lengths of dimensions
    @ntot = @shape.inject(1) { |prod, len| prod * len }  # total # of elements
    @data = BasicNumArray.new(@ntot)
  end

  # Kept so existing callers of _clone_ keep working.  It is now the
  # standard Object#clone, which deep-copies through initialize_copy below.
  alias _clone_ clone

  # Hook run by both #dup and #clone (Ruby >= 1.8).  Gives the copy its own
  # shape and data buffers so mutating one array ([]=, trim!, fill, ...)
  # never changes the other.  The arithmetic operators, #trim and #reshape
  # all build their results from dup and depend on this.
  def initialize_copy(orig)
    super
    @shape = @shape.dup
    @data = @data.clone
  end

  def_forwardings! :@data, :span, :product, :numtype, :max, :min,
                           :length, :size, :logical?, :negate!
  def_forwardings! :@data, :to_a, :fill
  def_forwardings! :filter, :each, :each_dim
  def_forwardings :@data, :sin, :cos, :tan, :exp, :log, :log10, :sqrt,
                          :abs, :ldexp, :+@, :-@, :**, :negate

  # numeric operators such as "+" and "*" are defined by def_methods below.
  Twoterm_operators = %w[+ - * /].freeze
  Twoterm_methods = %w[le lt ge gt and_ or_ atan2].freeze

  # 1D array generator, analogous to Array[]:
  #   NumArray[v1, v2, ..]   -> 1D NumArray whose data is BasicNumArray[v1, v2, ..]
  #   NumArray[basic_array]  -> 1D NumArray holding a copy of that BasicNumArray
  def NumArray.[](*ary)
    if ary.length == 1 && ary[0].is_a?(BasicNumArray)
      src = ary[0]
      out = new(src.length)
      out.instance_eval { @data = src.dup }
    else
      out = new(ary.length)
      out.instance_eval { @data = BasicNumArray[*ary] }
    end
    out
  end

  # Subset extraction.  The indices are interpreted by ArrayIndex; the
  # result is a new NumArray built from the cropped data and shape.
  def [](*idx)
    data, sub_shape = ArrayIndex.new(@shape, *idx).crop(@data)
    out = NumArray.new(sub_shape)
    out._data_(data)
    out
  end

  # Subset substitution: a[i, j, ..] = rhs (rhs is the last argument).
  def []=(*idx)
    rhs = idx.pop     # everything before the last argument is the index
    ArrayIndex.new(@shape, *idx).setval(@data, rhs)
    self
  end

  def shape; @shape.dup; end
  def ndims; @shape.length end
  alias rank ndims

  # Drop every dimension of length 1, in place.  Returns self.
  # A new Array is assigned instead of editing @shape in place, so an Array
  # shared with another object (e.g. one passed to _shape_) is never altered.
  def trim!
    @shape = @shape.reject { |len| len == 1 }
    self
  end

  # Non-destructive version of trim!.
  def trim
    dup.trim!
  end

  # Return a copy of the flat data as a BasicNumArray.
  def to_basic
    @data.dup
  end

  # Numeric coercion protocol: lets expressions such as 10 - numarray work.
  # A Numeric becomes a NumArray of the same shape filled with that value;
  # a BasicNumArray becomes a 1D NumArray wrapping it (not copied).
  # Raises TypeError for anything else.
  def coerce(other)
    if other.is_a?(Numeric)
      o = NumArray.new(*shape)
      o.fill(other)
      [o, self]
    elsif other.is_a?(BasicNumArray)
      o = NumArray.new(other.length)
      o._data_(other)
      [o, self]
    else
      raise TypeError, "#{other.class} can't be coerced into NumArray"
    end
  end

  # Rank <= 1 arrays print like their BasicNumArray data.  Higher ranks print
  # every element in storage order.  A line ends after each run of shape[0]
  # elements, and a blank line follows each shape[0]*shape[1] block.
  # For shape [2,2,2] with elements e0..e7:
  #   [e0, e1,
  #   e2, e3,
  #
  #   e4, e5,
  #   e6, e7]
  def inspect
    return @data.inspect if rank <= 1

    row  = @shape[0]
    slab = @shape[0] * @shape[1]
    str = "[".dup
    @ntot.times do |n|
      str << @data[n].to_s
      next if n == @ntot - 1      # no separator after the last element
      str << if n % slab == slab - 1 then ",\n\n"
             elsif n % row == row - 1 then ",\n"
             else ", "
             end
    end
    str << "]"
  end

  # Reshape the array in place, without changing the 1D order inside.
  #
  # Example: a[3,2,4] -> a[6,4]
  #
  # USAGE: array.reshape!(D1,D2,D3,..)  or  array.reshape!([D1,D2,D3,..])
  #        where D? is the length of the ?-th dimension.
  #        D1*D2*..*Dn must be the total length of the original array;
  #        otherwise a RuntimeError is raised.
  def reshape!(*shape)
    shape = shape[0] if shape.length == 1 && shape[0].is_a?(Array)
    ntot = shape.inject(1) { |prod, len| prod * len }
    raise("# of elements must be the same as that of the original array") if ntot != @ntot
    @shape = shape.dup
    self
  end

  # Non-destructive version of reshape!.
  def reshape(*shape)
    out = dup
    out.reshape!(*shape)
  end

  # Import a 1D BasicNumArray as the data entity (stored as-is, not copied).
  # Raises RuntimeError unless it is a BasicNumArray of the same length.
  def setvals(a)
    raise "Not a BasicNumArray" unless a.is_a?(BasicNumArray)
    raise "Invalid length" if a.length != self.length
    @data = a
    self
  end

  ######### PROTECTED METHODS ########
  protected
  def _data_(a)
    #(PROTECTED)
    # Import a 1D BasicNumArray as the data entity
    # [CAUTION] - validity is NOT checked
    #           - use dup if needed (obj._data_(a.dup))
    @data = a
  end
  def _shape_(a)
    #(PROTECTED)
    # Set a 1D Array as the shape
    # [CAUTION] - validity is NOT checked
    #           - use dup if needed (obj._shape_(a.dup))
    @shape = a
  end

  ######### PRIVATE METHODS ########
  private
  # Define the element-wise binary operations listed in Twoterm_operators
  # ("+", "-", ...) and Twoterm_methods ("lt", "and_", "atan2", ...).
  # Each one applies the same-named BasicNumArray operation to @data and
  # returns a copy of the receiver holding the result.  The operand may be a
  # Numeric, Array, BasicNumArray or NumArray.  Anything else goes through
  # other.coerce(self).
  def NumArray.def_methods
    # each (not for) so that every closure captures its own op
    (Twoterm_operators + Twoterm_methods).each do |op|
      define_method(op) do |other|
        if other.is_a?(Numeric) || other.is_a?(Array) ||
           other.is_a?(BasicNumArray)
          out = dup
          out._data_(@data.__send__(op, other))
          out
        elsif other.is_a?(NumArray)
          out = dup
          out._data_(@data.__send__(op, other.to_basic))
          out
        else
          coerce_me, coerce_other = other.coerce(self)
          coerce_me.__send__(op, coerce_other)
        end
      end
      public op   # make the visibility explicit, independent of `private` above
    end
  end

  def_methods
end

# Demonstration / smoke test, run when this file is executed directly.
# RUB and SKIP are not defined in this file (presumably in arrayindex).
if __FILE__ == $0
  p '== TEST =='
  print "**** create & span ****\n"
  print "a=NumArray.new(4,3,2) ; a.span;\n"
  a=NumArray.new(4,3,2) ; a.span
  print ' a = ';p a
  # Object#type was removed in Ruby 1.9; Object#class is the portable name.
  print " a.class, a.numtype, a.rank =  #{a.class}, #{a.numtype}, #{a.rank}\n"
  print " a.max, a.min, a.length =  #{a.max}, #{a.min}, #{a.length}\n"

  print "\n**** dup (clone) ****\n"
  b=a.dup ; b[0..2,0,0]=999
  print "b=a.dup ; b[0..2,0,0]=999 ;\n"
  print ' a[0..-1,0,0] = '; p a[0..-1,0,0]
  print ' b[0..-1,0,0] = '; p b[0..-1,0,0]

  print "\n**** reshape ****\n"
  print "b=a.reshape(6,4) ;\n" ; b=a.reshape(6,4)
  print " b.shape = " ; p b.shape

  print "\n**** to array ****\n"
  print " a.class, a.to_basic.class, a.to_a.class = ",
        "#{a.class}, #{a.to_basic.class}, #{a.to_a.class}\n"
  print " a.to_basic = "; p a.to_basic

  print "\n**** 1D creation ****\n"
  print "c=NumArray[-1,0,1,2,3] ;\n"; c=NumArray[-1,0,1,2,3]
  print ' c.rank = '; p c.rank
  print ' c = '; p c

  print "\n**** arithmetic operations ****\n"
  print ' c = '; p c
  print ' c+c =';p c+c
  print ' c-10 =';p c-10
  print ' 10-c =';p 10-c
  # parenthesized: `p -c**2` draws an "ambiguous first argument" warning
  print ' -c**2 =';p(-c**2)

  print "\n**** Math operations ****\n"
  print " c.sin = "; p c.sin
  print " c.exp = "; p c.exp
  print " c.abs = "; p c.abs
  print " c.ldexp(4) = "; p c.ldexp(4)
  # Math::PI: a bare PI only resolves if Math was included at top level
  print " c.atan2(c**2)/Math::PI = "; p c.atan2(c**2)/Math::PI

  print "\n**** subset extraction ****\n"
  print "b = a[true,0,0] ;\n" ; b = a[true,0,0]
  print "   (note) true is the same as 0..-1\n"
  print ' b = '; p b
  # Object#id was removed in Ruby 1.9; object_id is the portable name.
  print " b.object_id =  #{b.object_id}, a.object_id =  #{a.object_id} , ",
        "(a.object_id==b.object_id) = #{a.object_id==b.object_id}\n"
  # Hash index {0..-1 => 2} (presumably range => step, interpreted by
  # ArrayIndex).  The old "{0..-1,2}" literal form is a syntax error since 1.9.
  print "d=a[0..2,{(0..-1)=>2},0] ;\n"; d=a[0..2,{(0..-1)=>2},0]
  print " d = ";p d
  print " d.shape = "; p d.shape
  print " d.rank = "; p d.rank
  print "d.trim!;\n" ; d.trim!
  print " d.shape = "; p d.shape
  print " d.rank = "; p d.rank
  print "   (note): original dimensions are kept thru subset extraction.\n",
        "           use trim! (or trim) to trim off the dimensions of length 1\n"
  print "mask=NumArray[true,false,true,false];\n"
  mask=NumArray[true,false,true,false]
  print " a[mask,0,0] = "; p a[mask,0,0]

  print " a[[1,0,2],0,[0,1]] = \n"; p a[[1,0,2],0,[0,1]]

  print "mask=a.lt(15).and_(a.gt(4))\n"; mask=a.lt(15).and_(a.gt(4))
  print "a[mask] = \n"; p a[mask]
  print "   (note) As shown above, a mask can be applied not only to a \n",
        "          dimension but also to the whole array; the latter yields\n",
        "          a 1D NumArray\n"

  print "\n**** subset extraction (rubber dimension) ****\n"
  print "x = NumArray.new(4,2,3,2) ; x.span;\n"
  x = NumArray.new(4,2,3,2) ; x.span
  ##print " x = "; p x
  print " x[RUB,-1] = " ; p x[RUB,-1]
  print "   (note) RUB is substituted with an appropriate number of true (or 0..-1).\n",
        "          Thus x[RUB,-1] == x[true,true,true,-1] in this case\n"
  print " x[-1,RUB,-1] = " ; p x[-1,RUB,-1]
  print " x[0,RUB].trim! = " ; p x[0,RUB].trim!
  print " x[SKIP[1],0,RUB] = " ; p x[SKIP[1],0,RUB]
  # Every continuation line needs a trailing comma on the line before it:
  # a string literal on a line of its own is a separate statement and is
  # never printed.
  print "   (note) SKIP[count] is substituted with true for count times.\n",
        "          So the current expression is equivalent to x[true,0,RUB]\n",
        "          i.e., x[true,0,true,true], the 1st element of the 2nd dim.\n"

  print "\n**** subset substitution ****\n"
  print "a[true,1,true] = 999\n" ; a[true,1,true] = 999
  print " a = "; p a
  print "a[2,true,1] = [1111,2222,3333]\n" ; a[2,true,1] = [1111,2222,3333]
  print " a = "; p a

  print "a.span(0.0,0.2);\n" ; a.span(0.0,0.2)
  print "mask=a.gt(3.0)\n"; mask=a.gt(3.0)
  print "a[mask] = 99.0;\n"; a[mask] = 99.0
  print " a = \n" ; p a
  print "a[mask.negate] *= 1.01;\n"; a[mask.negate] *= 1.01
  print " a = \n" ; p a
end


