#! /usr/bin/env python
#*******************************************************************************
# ALMA - Atacama Large Millimeter Array
# (c) Associated Universities Inc., 2009 
# 
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
#
# "@(#) $Id: ObsCalBeaconMapping.py 247921 2017-08-08 15:07:45Z ahirota $"

#
# forcing global imports is due to an OSS problem
#
global copy
import copy

global CCL
import CCL.Global

global Control
import Control

global ControlExceptionsImpl
import ControlExceptionsImpl

global Observation
import Observation.DelayCalTarget
import Observation.SSRTuning
import Observation.ObsCalBase
import Observation.ObsTarget


global BeaconMapTarget
class BeaconMapTarget(Observation.ObsTarget.ObsTarget):
    """OBSERVE_TARGET scan that rasters a square map around a (beacon) source.

    The map is a bidirectional ``RectanglePattern`` expressed in RELATIVE
    horizon coordinates; every subscan generated by the pattern becomes one
    ON_SOURCE subscan in this target's subscan list.
    """

    # Upper bound on the number of subscans a single map may generate;
    # larger maps are rejected by _populateSubscanList().
    _maxSubscans = 250

    def __init__(self, SubscanFieldSource=None, SpectralSpec=None,
                 SubscanDuration=30,
                 schedBlockRef=None,
                 mapLength="10 arcmin",
                 mapStep="0.4 arcmin"):
        """Create the map target.

        Args:
            SubscanFieldSource: field source to map (passed to ObsTarget).
            SpectralSpec: spectral spec used for every subscan.
            SubscanDuration: duration of each map subscan (seconds,
                presumably -- it is forwarded to setSubscanDuration()).
            schedBlockRef: scheduling-block reference for ObsTarget.
            mapLength: side length of the square map, e.g. "10 arcmin".
            mapStep: row spacing of the raster, e.g. "0.4 arcmin".
        """
        from PyDataModelEnumeration import PyScanIntent
        import CCL.SubscanList
        Observation.ObsTarget.ObsTarget.__init__(self,
                                                 SubscanFieldSource,
                                                 SpectralSpec,
                                                 PyScanIntent.OBSERVE_TARGET,
                                                 schedBlockRef)
        self._subscanList = CCL.SubscanList.SubscanList()
        self.setSubscanDuration(SubscanDuration)
        # setMapPattern() imports CCL.APDMSchedBlock and stores the real
        # pattern; CCL.APDMSchedBlock must not be referenced before that
        # import has happened.
        self._pattern = None
        self.setMapPattern(mapLength, mapStep)

    def setMapPattern(self, mapLength, mapStep):
        """Replace the map pattern with a square, bidirectional raster.

        Args:
            mapLength: side length (longitude and latitude), e.g. "20 arcmin".
            mapStep: orthogonal step, presumably the spacing between
                adjacent scan rows, e.g. "0.4 arcmin".
        """
        import CCL.APDMSchedBlock
        import six
        rp = CCL.APDMSchedBlock.RectanglePattern()
        rp.patternCenterCoordinates.type = six.u("RELATIVE")
        rp.patternCenterCoordinates.system = six.u("horizon")
        rp.longitudeLength.set(mapLength)
        rp.latitudeLength.set(mapLength)
        rp.orthogonalStep.set(mapStep)
        # Scan rows alternate direction instead of always starting on one side.
        rp.uniDirectionalScan = False
        self._pattern = rp

    def _populateSubscanList(self):
        """Rebuild ``self._subscanList`` from the current map pattern.

        Raises:
            Exception: if the pattern yields more than ``_maxSubscans``
                subscans.
        """
        from PyDataModelEnumeration import PySubscanIntent
        self._subscanList.clear()
        ssl = self._pattern.generateSubscanList(self.getSubscanDuration())
        if len(ssl) > self._maxSubscans:
            raise Exception("Too many scanning rows")
        # Zero subreflector offsets for every subscan of the map.
        focusOffsets = [0., 0., 0.]
        # Attenuator settings are the same for every subscan: fetch them once.
        attenuatorSettingName = self.getAttenuatorSettingName()
        optimizeAttenuators = self.getAttenuatorOptimization()
        for sourceOffsetSpec in ssl:
            self._subscanList.addSubscan(
                self._source,
                self._spectralSpec,
                self._subscanDuration,
                PointingOffsetSpec=sourceOffsetSpec.pointingOffset,
                DelayOffsetSpec=None,
                SubreflectorOffset=focusOffsets,
                CalibrationDevice=self._observeWidget,
                SubscanIntent=PySubscanIntent.ON_SOURCE,
                attenuatorSettingName=attenuatorSettingName,
                optimizeAttenuators=optimizeAttenuators,
            )
            self._addMixerModeToLastSubscanIfRequired()


global ObsCalBeaconSupportBase
class ObsCalBeaconSupportBase(Observation.ObsCalBase.ObsCalBase):
    """Shared options and helpers for scripts that observe the beacon.

    Adds the beacon position (E/N/U), the pointing/delay/focus offset file
    names and the reference antenna to the script's options, and provides
    a scan-level exception decorator and the TelCal setup used by beacon
    observations.
    """

    def __init__(self):
        # The extra options must be registered before the base initializer
        # runs (it is called second here).
        self._addOptions()
        Observation.ObsCalBase.ObsCalBase.__init__(self)

    def _addOptions(self):
        """Append the beacon-specific script options to ``self.options``.

        NOTE(review): ``extend`` mutates whatever list ``self.options``
        resolves to -- normally the subclass's class-level ``options`` --
        in place, so instantiating the same class twice would register
        these options twice.
        """
        self.options.extend([
            Observation.ObsCalBase.scriptOption("E", float, -140.79),
            Observation.ObsCalBase.scriptOption("N", float, -5189.69),
            Observation.ObsCalBase.scriptOption("U", float, 321.36 + 7. - 0.15),
            Observation.ObsCalBase.scriptOption("pointingOffsetFile", str, ""),
            Observation.ObsCalBase.scriptOption("delayOffsetFile", str, ""),
            Observation.ObsCalBase.scriptOption("focusOffsetFile", str, ""),
            Observation.ObsCalBase.scriptOption("refAnt", str, ""),
        ])

    def parseCommonOptions(self):
        """Copy the beacon-specific parsed arguments onto the instance."""
        self.refAnt = self.args.refAnt
        self.bPosE = self.args.E
        self.bPosN = self.args.N
        self.bPosU = self.args.U
        self.pointingOffsetFile = self.args.pointingOffsetFile
        self.delayOffsetFile = self.args.delayOffsetFile
        self.focusOffsetFile = self.args.focusOffsetFile
        self.elLimit = "2 deg"

    @staticmethod
    def catchScanException(func):
        """Decorate a scan method so any failure closes the execution.

        On any exception (``BaseException``, so including KeyboardInterrupt
        and SystemExit) the wrapper logs an error, calls
        ``self.closeExecution(ex)`` and re-raises the original exception.
        """
        import functools

        @functools.wraps(func)
        def dFunc(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except BaseException as ex:
                self.logError("Error executing a scan")
                self.closeExecution(ex)
                raise
        return dFunc

    def setTelCalParams(self):
        """Configure TelCal pointing fits for the beacon (no-op if simulated)."""
        from Observation.Global  import simulatedArray
        if simulatedArray():
            return
        # Only announce the settings when they are actually applied.
        self.logInfo("Setting TelCal parameters pointingFitWidth=True, simpleGaussianFit=False")
        tcParameters = self.getTelCalParams()
        # It appears that letting TelCal fit the width works better than
        # fixing it, probably because the beacon signal is not a point source.
        tcParameters.setCalibParameter('pointingFitWidth', True)
        tcParameters.setCalibParameter('simpleGaussianFit',False)


class ObsCalBeaconMapping(ObsCalBeaconSupportBase):
    """Beacon mapping script.

    Optionally runs pointing, focus and ATM calibrations on the beacon
    source, then a raster map (BeaconMapTarget) around it.  The beacon
    position/offset options are appended to ``options`` by
    ObsCalBeaconSupportBase.
    """

    options = [
        Observation.ObsCalBase.scriptOption("band", int, 3),
        Observation.ObsCalBase.scriptOption("dumpDuration", float, 0.192),
        Observation.ObsCalBase.scriptOption("channelAverageDuration", float, 0.384),
        Observation.ObsCalBase.scriptOption("subscanDuration", float, 9.216),
        # ACA requires more than about 8.5 sec in FDM mode for 16-ant array
        # (presumably refers to subscanDuration above).
        Observation.ObsCalBase.scriptOption("integrationDuration", float, 0.384),
        Observation.ObsCalBase.scriptOption("tpIntegrationDuration", float, 0.016),
        Observation.ObsCalBase.scriptOption("bbNames", str, ""),
        Observation.ObsCalBase.scriptOption("frequency", float, -1.),
        Observation.ObsCalBase.scriptOption("bbFreqs", str, ""),
        Observation.ObsCalBase.scriptOption("corrMode", str, "TDM"),
        Observation.ObsCalBase.scriptOption("mapLength", str, "20 arcmin"),
        Observation.ObsCalBase.scriptOption("mapStep", str, "0.4 arcmin"),
        Observation.ObsCalBase.scriptOption("doPointing", bool, False),
        Observation.ObsCalBase.scriptOption("doFocus", bool, False),
        Observation.ObsCalBase.scriptOption("noMap", bool, False),
        Observation.ObsCalBase.scriptOption("doATM", bool, False),
        Observation.ObsCalBase.scriptOption("useSQLDForATMProcessing", bool, False),
        # 2019/11/17 add cross-correlation option
        Observation.ObsCalBase.scriptOption("crossCorr", bool, False),
    ]

    def parseOptions(self):
        """Copy parsed command-line arguments onto instance attributes.

        ``bbNames`` is a comma-separated string; when it is empty (or None)
        ``self.bbNames`` is None, which is forwarded unchanged to
        GenerateSpectralSpec.
        """
        self.band                    = self.args.band
        self.dumpDuration            = self.args.dumpDuration
        self.channelAverageDuration  = self.args.channelAverageDuration
        self.subscanDuration         = self.args.subscanDuration
        self.integrationDuration     = self.args.integrationDuration
        self.tpIntegrationDuration   = self.args.tpIntegrationDuration

        bbNameStr = self.args.bbNames
        self.bbNames = bbNameStr.split(',') if bbNameStr else None

        ObsCalBeaconSupportBase.parseCommonOptions(self)

    def generateTunings(self):
        """Build ``self._spectralSpec`` and set narrow channel-average regions.

        The tuning frequency is the ``frequency`` option when positive (all
        basebands then share it); otherwise the per-band default from
        ``Observation.SSRTuning.bandFreqs_delayMeasurement``.  For TDM and
        FDM the first channel-average region of every spectral window is
        narrowed to a few channels centred in the window.
        """
        corrMode = self.args.corrMode

        if self.args.frequency > 0:
            frequency = self.args.frequency
            sameBBFreqs = True
        else:
            # Only consult the per-band table when no explicit frequency is given.
            frequency = Observation.SSRTuning.bandFreqs_delayMeasurement[self.band]
            sameBBFreqs = False

        if self.args.bbFreqs != "":
            bbFreqs = [float(token) for token in self.args.bbFreqs.split(",")]
        else:
            bbFreqs = None

        bwd = 16 if corrMode == "FDM" else None

        corrType = self._array.getCorrelatorType()
        self._spectralSpec = self._tuningHelper.GenerateSpectralSpec(
            band=self.band,
            intent="interferometry_continuum",
            frequency=frequency,
            bbNames=self.bbNames,
            SBPref=None,
            sameBBFreqs=sameBBFreqs,
            bbFreqs=bbFreqs,
            corrType=corrType,
            corrMode=corrMode,
            bwd=bwd,
            dualMode=True,
            dump=self.dumpDuration,
            channelAverage=self.channelAverageDuration,
            integration=self.integrationDuration
        )
        if corrMode == "TDM":
            # Updated: 2019/10/15
            # Assumes 128 channels per TDM spectral window -- TODO confirm.
            # Floor division keeps the channel index an int under Python 3.
            chanAverNumChans = 4
            chanAverStartChan = (128 - chanAverNumChans) // 2
            corrConfig = self._spectralSpec.BLCorrelatorConfiguration
            for bbc in corrConfig.BLBaseBandConfig:
                for spw in bbc.BLSpectralWindow:
                    chavg = spw.ChannelAverageRegion[0]
                    chavg.numberChannels = chanAverNumChans
                    chavg.startChannel = chanAverStartChan
        elif corrMode == "FDM":
            # FDM - will be removed anyway
            chanAverNumChans = 8
            spAvgFactor = 8
            factor = 1
            corrConfig = self._spectralSpec.BLCorrelatorConfiguration
            for bbc in corrConfig.BLBaseBandConfig:
                for spw in bbc.BLSpectralWindow:
                    spw.spectralAveragingFactor = spAvgFactor // factor
                    spw.effectiveNumberOfChannels //= factor
                    nCh = spw.effectiveNumberOfChannels
                    self.logInfo("%s %s" % (nCh, type(nCh)))
                    chavg = spw.ChannelAverageRegion[0]
                    # Centre the averaged channels; floor division avoids a
                    # float start channel under Python 3.
                    chanAverStartChan = (nCh - chanAverNumChans) // 2
                    chavg.startChannel = chanAverStartChan // factor
                    chavg.numberChannels = chanAverNumChans // factor

    def setReferenceAntenna(self, refAnt=None):
        """Tell TelCal which antenna to use as reference (no-op if simulated).

        Args:
            refAnt: antenna name; defaults to ``self.refAnt`` (the ``refAnt``
                option, or the antenna chosen in findPointFocusSource()).
        """
        import Observation.Global
        if refAnt is None:
            refAnt = self.refAnt
        if Observation.Global.simulatedArray():
            return
        tcParameters = self.getTelCalParams()
        self.logInfo("Set %s as the reference antenna" % (refAnt))
        tcParameters.parameterTuning.setCalibParameterAsString(
            'refantenna', refAnt, self.arrayName, 1)

    def doAtmCal(self):
        """Run an ATM calibration on the beacon when the doATM option is set.

        On failure the traceback is logged, the execution is closed and the
        original exception is re-raised.
        """
        import Observation.AtmCalTarget
        import CCL.SourceOffset
        import math
        if not self.args.doATM:
            return
        src = self._srcPointFocus
        ss = self._spectralSpec
        try:
            atm = Observation.AtmCalTarget.AtmCalTarget(src, ss,
                                                        doHotLoad=True)
            atm.setOnlineProcessing(True)
            if self.args.useSQLDForATMProcessing:
                atm.setDataOrigin("TOTAL_POWER")
            else:
                atm.setDataOrigin('FULL_RESOLUTION_AUTO')
            atm.setDoZero(False)
            atm.setSubscanDuration(5.76)
            atm.setIntegrationTime(1.5)
            atm.setWVRCalReduction(True)
            atm.setApplyWVR(False)
            # The reference (off) position is the beacon itself displaced by
            # a stroke of 600 arcsec in the HORIZON frame (presumably in
            # azimuth -- confirm against CCL.SourceOffset.stroke).
            atm._referenceSource = atm._source
            # TODO change the sign per each antenna to avoid crossing Az=+180/-180
            atm._referenceOffset = CCL.SourceOffset.stroke(math.radians(600 / 3600.), 0, 0, 0, Control.HORIZON)
            self.logInfo('Executing AtmCal on ' + str(src.sourceName) + '...')
            atm.execute(self._obsmode)
            self.logInfo('Completed AtmCal on ' + str(src.sourceName))
        except BaseException as ex:
            import traceback
            self.logError(traceback.format_exc())
            msg = "Error executing AtmCal on source '%s'" % str(src.sourceName)
            self.logException(msg, ex)
            self.closeExecution(ex)
            # Bare raise keeps the original traceback (``raise ex`` drops it on Python 2).
            raise

    @ObsCalBeaconSupportBase.catchScanException
    def doMap(self):
        """Run the raster map around the beacon unless noMap is set."""
        if self.args.noMap:
            return
        src = self._srcPointFocus
        ss = self._spectralSpec
        target = BeaconMapTarget(src, ss,
                                 SubscanDuration=self.subscanDuration,
                                 mapLength=self.args.mapLength,
                                 mapStep=self.args.mapStep)
        self.logInfo('Executing map scan on ' + src.sourceName + '...')
        target.execute(self._obsmode)
        self.logInfo('Completed map scan on ' + src.sourceName)

    @ObsCalBeaconSupportBase.catchScanException
    def doPointing(self):
        """Run a CROSS pointing calibration on the beacon and apply it.

        Only runs when the doPointing option is set.  Raises if no result
        is obtained, unless the array name contains "OSS".
        """
        if not self.args.doPointing:
            return
        from Observation.PointingCalTarget import PointingCalTarget
        from PyDataModelEnumeration import PyCalDataOrigin
        from PyDataModelEnumeration import PyPointingMethod

        src = self._srcPointFocus
        ss = self._spectralSpec
        target = PointingCalTarget(src, ss)
        target.setSubscanDuration(5.76 * 2)
        if self.args.crossCorr:
            target.setDataOrigin(PyCalDataOrigin.CHANNEL_AVERAGE_CROSS)
        else:
            target.setDataOrigin(PyCalDataOrigin.CHANNEL_AVERAGE_AUTO)
        target.setPointingMethod(PyPointingMethod.CROSS)
        # 2019/11/17: the required amount of pointing correction is now less
        # than 20 arcsec, so a small excursion is enough (was 120/150 arcsec).
        target.setExcursion("75arcsec")
        target._verbose = True

        self.logInfo('Executing PointingCal on ' + src.sourceName + '...')
        target.execute(self._obsmode)
        self.logInfo('Completed PointingCal on ' + src.sourceName)

        result = target.checkResult(self._array)
        self.logInfo("Result is: %s" % str(result))

        if len(result) > 0:
            target.applyResult(self._obsmode, result)
        elif "OSS" not in self._array._arrayName:
            raise Exception("No pointing results!")

    @ObsCalBeaconSupportBase.catchScanException
    def doFocus(self):
        """Run a focus calibration on the beacon and log (not apply) the result.

        Only runs when the doFocus option is set.  Failures while reading
        the result are logged as warnings and do not abort the script.
        """
        if not self.args.doFocus:
            return
        from Observation.FocusCalTarget import FocusCalTarget
        from PyDataModelEnumeration import PyCalDataOrigin

        src = self._srcPointFocus
        ss = self._spectralSpec
        dataOrigin = PyCalDataOrigin.CHANNEL_AVERAGE_AUTO
        target = FocusCalTarget(src, ss,
                                SubscanDuration=5.76,
                                DataOrigin=dataOrigin,
                                OneWay=False,
                                NumPositions=7)
        # TODO: Set excursion, just in case (units presumably metres -- confirm).
        target.setExcursion(1e-3)

        self.logInfo('Executing FocusCal on ' + src.sourceName + '...')
        target.execute(self._obsmode)
        self.logInfo('Completed FocusCal on ' + src.sourceName)

        # The focus result is only logged; it is deliberately not applied.
        try:
            result = target.checkResult(self._array)
            self.logInfo("Result is: %s" % str(result))
        except Exception:
            # Best effort: a bad result must not abort the observation, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            import traceback
            self.logWarning(traceback.format_exc())

    def findPointFocusSource(self):
        """Create the beacon FieldSource and store it in ``_srcPointFocus``.

        Pointing and delay offsets are read from the configured files (the
        focus offset file is currently not used here).  If no reference
        antenna was given, one is selected by the helper.
        """
        from Observation.ArtificialSourceHelper import ArtificialSourceHelper
        bPosENU = [self.bPosE, self.bPosN, self.bPosU]
        builder = ArtificialSourceHelper(bPosENU)
        builder.readOffsetsFromFile("pointing", self.pointingOffsetFile)
        builder.readOffsetsFromFile("delay", self.delayOffsetFile)
        if self.refAnt == "":
            self.refAnt = builder.selectReferenceAntenna()
        src = builder.createFieldSource(referenceAntenna=self.refAnt)
        self.logInfo("beacon FieldSource: %s" % str(src.toDOM().toxml()))
        self._srcPointFocus = src


# Script entry point: prepare the execution, then run the calibration and
# mapping sequence.  Any failure closes the execution before re-raising.
obs = ObsCalBeaconMapping()
obs.parseOptions()
obs.checkAntennas()
obs.startPrepareForExecution()
try:
    obs.generateTunings()
    obs.findPointFocusSource()
except BaseException as ex:
    import traceback
    obs.logError(str(traceback.format_exc()))
    obs.logException("Error in methods run during execution/obsmode startup", ex)
    obs.completePrepareForExecution()
    obs.closeExecution(ex)
    raise
obs.completePrepareForExecution()
# These calls are not wrapped by catchScanException, so close the execution
# explicitly if they fail (same pattern as the scan methods).
try:
    obs.setTelCalParams()
    obs.logInfo("Executing a map scan...")
    obs.setReferenceAntenna()
except BaseException as ex:
    obs.logException("Error configuring TelCal parameters", ex)
    obs.closeExecution(ex)
    raise
# Pointing is run twice, presumably so the second pass refines the first
# correction -- confirm.  Each call is a no-op unless doPointing is set.
obs.doPointing()
obs.doPointing()
if obs.args.doFocus:
    # Re-point after the focus calibration.
    obs.doFocus()
    obs.doPointing()
obs.doAtmCal()
obs.doMap()
obs.closeExecution()
