# coding: utf-8
=begin
= Lomb-Scargle periodogram
=end

require "numru/gphys"

module NumRu
  module GAnalysis
    module LombScargle
      module_function

      # = Lomb-Scargle periodogram
      #
      # For each frequency, the data are projected along dimension dim onto
      # cos(2*pi*f*t + ph) and sin(2*pi*f*t + ph):
      #   a = sum(y*cos)/sum(cos**2),  b = sum(y*sin)/sum(sin**2)
      # where the phase ph is obtained from derive_phase.
      #
      # ARGUMENTS
      # * y [NArray or NArrayMiss] : data
      # * dim [Integer] : the dimension to apply LS (any dimension of y)
      # * t [1D NArray] : "time" (the coordinate values along dim)
      # * f [Numeric or 1D NArray] : frequency. If 1D NArray, the frequency
      #   dimension is added to the end of outputs
      #
      # OUTPUTS
      # * a : coef of cos (shape of y without dim; plus the frequency
      #   dimension at the end if f is a 1D NArray)
      # * b : coef of sin  (sqrt(a**2+b**2) is the amplitude)
      # * reconst : reconstructed y (fitting result; summed over all
      #   frequencies if f is a 1D NArray)
      # * cos [NArray or NArrayMiss]: cos(2*pi*f*t + ph)
      #   (useful if you want to reconstruct for a part of f)
      # * sin [NArray or NArrayMiss]: sin(2*pi*f*t + ph)
      #   (useful if you want to reconstruct for a part of f)
      # * cc [NArray or NArrayMiss]: normalization factor sum(cos**2)
      #   (needed to derive power spectrum)
      # * ss [NArray or NArrayMiss]: normalization factor sum(sin**2)
      #   (needed to derive power spectrum)
      # * ph [NArray or NArrayMiss]: phase (for each f and,
      #   if with miss, for each grid point)
      def lomb_scargle(y, dim, t, f)
        rank = y.rank
        # Only mask-aware arrays (e.g. NArrayMiss) provide get_mask.
        mask = y.respond_to?(:get_mask) ? y.get_mask : nil
        multiple_f = !f.is_a?(Numeric)

        ph, ot_ph = derive_phase(t,f,dim,rank,mask)
        cos = Misc::EMath.cos( ot_ph )
        sin = Misc::EMath.sin( ot_ph )
        cc = (cos*cos).sum(dim)
        ss = (sin*sin).sum(dim)

        # With multiple frequencies, cos/sin carry a trailing frequency
        # dimension; give y a matching length-1 dimension for broadcasting.
        y = y.newdim(-1) if multiple_f

        a = (y*cos).sum(dim) / cc
        b = (y*sin).sum(dim) / ss

        # Re-insert the summed-out dimension at its original position (dim)
        # so that the coefficients broadcast against cos/sin regardless of
        # which dimension was fitted.
        a_ex = a.newdim(dim)
        b_ex = b.newdim(dim)
        if multiple_f
          # the last dimension is frequency: superpose all the components
          reconst = (a_ex*cos).sum(-1) + (b_ex*sin).sum(-1)
        else
          # single frequency: there is no frequency dimension to sum over
          reconst = a_ex*cos + b_ex*sin
        end
        [a,b,reconst,cos,sin,cc,ss,ph]
      end

      # Derives the phase ph and the argument omega*t + ph used for the
      # cos/sin fitting (omega = 2*pi*f).
      #
      # ph is set to 0.5*atan2( sum(sin(2*omega*t)), -sum(cos(2*omega*t)) )
      # with the sums taken along dim, which makes
      # sum(sin(2*omega*t + 2*ph)) vanish along dim.
      #
      # ARGUMENTS
      # * t [1D NArray] : coordinate values along dim
      # * f [Numeric or 1D NArray] : frequency (a frequency dimension is
      #   appended at the end if 1D NArray)
      # * dim [Integer] : the dimension along which t varies
      # * rank [Integer] : rank of the data to be fitted
      # * mask [NArray(byte) or nil] : validity mask of the data; used only
      #   if it has one or more false elements
      #
      # RETURN VALUE
      # * [ph, ot_ph] : the phase (length 1 along dim) and omega*t + ph
      #   (NArrayMiss if the mask has false elements)
      def derive_phase(t,f,dim,rank,mask=nil)
        single_f = f.is_a?(Numeric)
        if single_f
          two_omega_t = 4*Math::PI * f * t
        else
          two_omega_t = 4*Math::PI * f.newdim(0) * t.newdim(-1)
          mask = mask.newdim(-1) if mask
          nfreq = f.length
        end
        has_missing = (mask && mask.count_false > 0)

        # Pad with length-1 dimensions so that t lies along dim
        # (the frequency dimension, if any, stays last).
        if dim >= 1
          leading = [0]*dim
          two_omega_t = two_omega_t.newdim!( *leading )
        end
        if dim < rank-1
          trailing = [dim+1]*(rank-1-dim)
          two_omega_t = two_omega_t.newdim!( *trailing )
        end

        if has_missing
          # Sampling differs among grid points: expand to the full shape
          # by adding zeros, then attach the mask.
          zeros = NArray.float( *mask.shape )
          two_omega_t = two_omega_t + zeros
          unless single_f
            # replicate the mask for all the frequencies
            freq_zeros = NArray.byte(*([1]*rank+[nfreq]))
            mask = mask + freq_zeros
          end
          two_omega_t = NArrayMiss.to_nam(two_omega_t,mask)
        end

        sin_sum = Misc::EMath.sin(two_omega_t).sum(dim)
        cos_sum = Misc::EMath.cos(two_omega_t).sum(dim)
        ph = 0.5* Misc::EMath::atan2( sin_sum, -cos_sum ).newdim(dim)
        ot_ph = two_omega_t*0.5 + ph  # omega*t + ph
        [ph, ot_ph]
      end
    end
  end

  class GPhys
    # = Lomb-Scargle periodogram
    #
    # The direct outputs are the amplitudes of each wave component,
    # ga as the coefficients for cos and gb as the coefficients for sin.
    #
    # LS power spectrum can be derived as
    #   pw = ga**2 * cc/(ntlen*df) + gb**2 * ss/(ntlen*df)
    #      = ga**2 * cc/(2*fmax) + gb**2 * ss/(2*fmax)
    # Here, ntlen=2*f.length (equivalent data length from frequency sampling).
    #
    # ARGUMENTS
    # * dim [Integer or String] : dimension to apply fitting
    # * fmax [Numeric]: max frequency (sampling frequencies [df, 2*df,...,fmax])
    #   fmax should be a multiple of df
    # * df [Numeric/nil]: frequency increment (If nil, set to fmax: single freq)
    # * f_long_name, f_name : long_name and name of the "frequency" axis.
    #   You may want to specify them if the dimension is not time.
    #
    # OUTPUTS
    # * ga [GPhys]: coef of cos. Its grid consists of the original axes
    #   except dim, followed by the frequency axis (placed last).
    # * gb [GPhys]: coef of sin  (sqrt(a**2+b**2) is the amplitude).
    #   On the same grid as ga.
    # * greconst [GPhys]: reconstructed y (fitting result) on the original grid
    # * cos [NArray or NArrayMiss]: cos(2*pi*f*t + ph)
    #   (useful if you want to reconstruct for a part of f)
    # * sin [NArray or NArrayMiss]: sin(2*pi*f*t + ph)
    #   (useful if you want to reconstruct for a part of f)
    # * cc [NArray or NArrayMiss]: normalization factor sum(cos**2)
    #   (needed to derive power spectrum)
    # * ss [NArray or NArrayMiss]: normalization factor sum(sin**2)
    #   (needed to derive power spectrum)
    # * ph [NArray or NArrayMiss]: phase (for each f and,
    #   if with miss, for each grid point)
    # * f [1D NArray]: frequencies derived from df and fmax

    def lomb_scargle(dim, fmax, df=nil, f_long_name="frequency", f_name="f")

      #< prep >
      df = fmax if df.nil?
      dim = dim_index( dim )
      ct = coord(dim)
      t = ct.val
      nf = (fmax/df).round
      f = df * (NArray.float(nf).indgen!+1.0)   # [df, 2*df, ..., nf*df]
      y = val

      #< do it >
      a, b, reconst, cos, sin, cc, ss, ph =
                   GAnalysis::LombScargle::lomb_scargle(y, dim, t, f)

      #< to gphys >
      # frequency units: inverse of the coordinate units (without "since ...")
      tun = Units.new( ct.units.to_s.sub(/ *since.*$/,'') )
      fun = (tun**(-1)).to_s
      fax = Axis.new.set_pos(
            VArray.new(f, {"long_name"=>f_long_name,"units"=>fun}, f_name) )
      # The coefficient arrays have dim removed and the frequency dimension
      # appended at the end, so the grid must be arranged in the same way
      # (simply replacing the axis at dim is valid only for the last dim).
      caxes = (0...rank).map{|d| axis(d)}
      caxes.delete_at(dim)
      caxes.push(fax)
      cgrid = Grid.new(*caxes)  # grid of fitting (~Fourier) coefficients
      un = units.to_s
      ga = GPhys.new( cgrid,
                      VArray.new(a,{"long_name"=>"coef a","units"=>un},"a") )
      gb = GPhys.new( cgrid,
                      VArray.new(b,{"long_name"=>"coef b","units"=>un},"b") )
      # the reconstruction inherits the attributes and name of the data
      greconst = GPhys.new( grid, VArray.new(reconst,data,name) )
      [ga, gb, greconst, cos, sin, cc, ss, ph, f]
    end
  end

end

################################################
##### test part ######
if $0 == __FILE__
  # Self test: fits synthetic two-point data made of known harmonics and
  # prints the coefficients, phase, amplitudes and reconstruction error.
  require "numru/ggraph"
  include NumRu
  include Misc::EMath

  cx = VArray.new(NArray.float(2).indgen!, {"units"=>"1","long_name"=>"x"}, "x")

  ## unequal spacing : fitting
  #t = NArray.to_na( [0.0, 0.1, 1, 1.5, 3, 3.3, 3.8, 6, 6.3, 7, 7.5, 8, 9] )

  ## equal spacing with missing
  #t = NArray.to_na( [0.0, 2, 3, 4, 5, 6, 7, 8, 9] )

  # equal spacing: equal to DFT if f is multiples of 1/10
  t = NArray.float(10).indgen!

  ct = VArray.new(t, {"units"=>"days","long_name"=>"time"},"t")
  grid = Grid.new( Axis.new.set_pos(cx), Axis.new.set_pos(ct) )
  nt = t.length

  v = NArrayMiss.float(2,nt)
  f = 1.0/10.0  # base frequency
  o = 2*PI*f
  v[0,true] = sin(o*t) + 0.4*sin(2*o*t) + 0.2*cos(2*o*t) + 0.3*cos(3*o*t)
  v[1,true] = cos(o*t) + 0.2*sin(2*o*t) + 0.2*cos(2*o*t) + 2*sin(3*o*t)

  v.invalidation(1,1)  # then, v[1,true] fitting will be different from DFT

  y = VArray.new(v, {"units"=>"m/s","long_name"=>"y"}, "y")
  gp = GPhys.new(grid, y)

  df = f
  fmax = 4*df
  #df = f / 2
  #fmax = 8*df
  # GPhys#lomb_scargle returns [ga, gb, greconst, cos, sin, cc, ss, ph, f];
  # all nine values are unpacked so that ph really is the phase (8th value).
  ga, gb, greconst, _cos, _sin, _cc, _ss, ph, _f = gp.lomb_scargle("t",fmax,df)
  p "**", ga, ph
  ampsp = ga**2 + gb**2
  amp = Misc::EMath.sqrt(ampsp.val)
  p "&&", amp, ampsp.val.sum(-1)/2, ((greconst.val - v)**2).mean(-1)
end
