pism

[fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
git clone git://src.adamsgaard.dk/pism # fast
git clone https://src.adamsgaard.dk/pism.git # slow
Log | Files | Refs | README | LICENSE Back to index

ssa_tao.py (10577B)


      1 # Copyright (C) 2012, 2014, 2015, 2016, 2018 David Maxwell and Constantine Khroulev
      2 #
      3 # This file is part of PISM.
      4 #
      5 # PISM is free software; you can redistribute it and/or modify it under the
      6 # terms of the GNU General Public License as published by the Free Software
      7 # Foundation; either version 3 of the License, or (at your option) any later
      8 # version.
      9 #
     10 # PISM is distributed in the hope that it will be useful, but WITHOUT ANY
     11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
     12 # FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
     13 # details.
     14 #
     15 # You should have received a copy of the GNU General Public License
     16 # along with PISM; if not, write to the Free Software
     17 # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
     18 
     19 """Inverse SSA solvers using the TAO library."""
     20 
     21 import PISM
     22 from PISM.util import Bunch
     23 from PISM.logging import logError
     24 from PISM.invert.ssa import InvSSASolver
     25 
     26 import sys
     27 import traceback
     28 
     29 
     30 class InvSSASolver_Tikhonov(InvSSASolver):
     31 
     32     """Inverse SSA solver based on Tikhonov iteration using TAO."""
     33 
     34     # Dictionary converting PISM algorithm names to the corresponding
     35     # TAO algorithms used to implement the Tikhonov minimization.
     36     tao_types = {}
     37 
     38     if (not PISM.imported_from_sphinx) and PISM.PETSc.Sys.getVersion() < (3, 5, 0):
     39         tao_types = {'tikhonov_lmvm': 'tao_lmvm',
     40                      'tikhonov_cg': 'tao_cg',
     41                      'tikhonov_lcl': 'tao_lcl',
     42                      'tikhonov_blmvm': 'tao_blmvm'}
     43     else:
     44         tao_types = {'tikhonov_lmvm': 'lmvm',
     45                      'tikhonov_cg': 'cg',
     46                      'tikhonov_lcl': 'lcl',
     47                      'tikhonov_blmvm': 'blmvm'}
     48 
     49 
     50     def __init__(self, ssarun, method):
     51         """
     52         :param ssarun: The :class:`PISM.invert.ssa.SSAForwardRun` defining the forward problem.
     53         :param method: String describing the actual algorithm to use. Must be a key in :attr:`tao_types`."""
     54 
     55         InvSSASolver.__init__(self, ssarun, method)
     56         self.listeners = []
     57         self.solver = None
     58         self.ip = None
     59         if self.tao_types.get(method) is None:
     60             raise ValueError("Unknown TAO Tikhonov inversion method: %s" % method)
     61 
     62     def addIterationListener(self, listener):
     63         """Add a listener to be called after each iteration.  See :ref:`Listeners`."""
     64         self.listeners.append(listener)
     65 
     66     def addDesignUpdateListener(self, listener):
     67         """Add a listener to be called after each time the design variable is changed."""
     68         self.listeners.append(listener)
     69 
     70     def solveForward(self, zeta, out=None):
     71         r"""Given a parameterized design variable value :math:`\zeta`, solve the SSA.
     72         See :cpp:class:`IP_TaucParam` for a discussion of parameterizations.
     73 
     74         :param zeta: :cpp:class:`IceModelVec` containing :math:`\zeta`.
     75         :param out: optional :cpp:class:`IceModelVec` for storage of the computation result.
     76         :returns: An :cpp:class:`IceModelVec` contianing the computation result.
     77         """
     78         ssa = self.ssarun.ssa
     79 
     80         reason = ssa.linearize_at(zeta)
     81         if reason.failed():
     82             raise PISM.AlgorithmFailureException(reason)
     83         if out is not None:
     84             out.copy_from(ssa.solution())
     85         else:
     86             out = ssa.solution()
     87         return out
     88 
     89     def solveInverse(self, zeta0, u_obs, zeta_inv):
     90         r"""Executes the inversion algorithm.
     91 
     92         :param zeta0: The best `a-priori` guess for the value of the parameterized design variable :math:`\zeta`.
     93         :param u_obs: :cpp:class:`IceModelVec2V` of observed surface velocities.
     94         :param zeta_inv: :cpp:class:`zeta_inv` starting value of :math:`\zeta` for minimization of the Tikhonov functional.
     95         :returns: A :cpp:class:`TerminationReason`.
     96         """
     97         eta = self.config.get_number("inverse.tikhonov.penalty_weight")
     98 
     99         design_var = self.ssarun.designVariable()
    100         if design_var == 'tauc':
    101             if self.method == 'tikhonov_lcl':
    102                 problemClass = PISM.IP_SSATaucTaoTikhonovProblemLCL
    103                 solverClass = PISM.IP_SSATaucTaoTikhonovProblemLCLSolver
    104                 listenerClass = TaucLCLIterationListenerAdaptor
    105             else:
    106                 problemClass = PISM.IP_SSATaucTaoTikhonovProblem
    107                 solverClass = PISM.IP_SSATaucTaoTikhonovSolver
    108                 listenerClass = TaucIterationListenerAdaptor
    109         elif design_var == 'hardav':
    110             if self.method == 'tikhonov_lcl':
    111                 problemClass = PISM.IP_SSAHardavTaoTikhonovProblemLCL
    112                 solverClass = PISM.IP_SSAHardavTaoTikhonovSolverLCL
    113                 listenerClass = HardavLCLIterationListenerAdaptor
    114             else:
    115                 problemClass = PISM.IP_SSAHardavTaoTikhonovProblem
    116                 solverClass = PISM.IP_SSAHardavTaoTikhonovSolver
    117                 listenerClass = HardavIterationListenerAdaptor
    118         else:
    119             raise RuntimeError("Unsupported design variable '%s' for InvSSASolver_Tikhonov. Expected 'tauc' or 'hardness'" % design_var)
    120 
    121         tao_type = self.tao_types[self.method]
    122         (stateFunctional, designFunctional) = PISM.invert.ssa.createTikhonovFunctionals(self.ssarun)
    123 
    124         self.ip = problemClass(self.ssarun.ssa, zeta0, u_obs, eta, stateFunctional, designFunctional)
    125         self.solver = solverClass(self.ssarun.grid.com, tao_type, self.ip)
    126 
    127         max_it = int(self.config.get_number("inverse.max_iterations"))
    128         self.solver.setMaximumIterations(max_it)
    129 
    130         pl = [listenerClass(self, l) for l in self.listeners]
    131 
    132         for l in pl:
    133             self.ip.addListener(l)
    134 
    135         self.ip.setInitialGuess(zeta_inv)
    136 
    137         vecs = self.ssarun.modeldata.vecs
    138         if vecs.has('zeta_fixed_mask'):
    139             self.ssarun.ssa.set_tauc_fixed_locations(vecs.zeta_fixed_mask)
    140 
    141         return self.solver.solve()
    142 
    143     def inverseSolution(self):
    144         """Returns a tuple ``(zeta,u)`` of :cpp:class:`IceModelVec`'s corresponding to the values
    145         of the design and state variables at the end of inversion."""
    146         zeta = self.ip.designSolution()
    147         u = self.ip.stateSolution()
    148         return (zeta, u)
    149 
    150 
    151 class TaucLCLIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemLCLListener):
    152 
    153     """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
    154     on to a standard python-based listener.  Used internally by
    155     :class:`InvSSATaucSolver_Tikhonov`.  I.e. don't make one of these for yourself."""
    156 
    157     def __init__(self, owner, listener):
    158         """:param owner: The :class:`InvSSATaucSolver_Tikhonov` that constructed us
    159            :param listener: The python-based listener.
    160          """
    161         PISM.IP_SSATaucTaoTikhonovProblemLCLListener.__init__(self)
    162         self.owner = owner
    163         self.listener = listener
    164 
    165     def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, constraints):
    166         """Called during IP_SSATaucTaoTikhonovProblemLCL iterations.  Gathers together the long list of arguments
    167         into a dictionary and passes it along in standard form to the python listener."""
    168 
    169         data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
    170                      zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
    171                      u=u, residual=diff_u, grad_JState=grad_u,
    172                      constraints=constraints)
    173         try:
    174             self.listener(self.owner, it, data)
    175         except Exception:
    176             logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
    177             traceback.print_exc(file=sys.stdout)
    178             raise
    179 
    180 
    181 class TaucIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemListener):
    182 
    183     """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
    184     on to a standard python-based listener.  Used internally by
    185     :class:`InvSSATaucSolver_Tikhonov`.  I.e. don't make one of these for yourself."""
    186 
    187     def __init__(self, owner, listener):
    188         """:param owner: The :class:`InvSSATaucSolver_Tikhonov` that constructed us
    189            :param listener: The python-based listener.
    190          """
    191         PISM.IP_SSATaucTaoTikhonovProblemListener.__init__(self)
    192         self.owner = owner
    193         self.listener = listener
    194 
    195     def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
    196         """Called during IP_SSATaucTaoTikhonovProblem iterations.  Gathers together the long list of arguments
    197         into a dictionary and passes it along in a standard form to the python listener."""
    198         data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
    199                      zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
    200                      u=u, residual=diff_u, grad_JState=grad_u, grad_JTikhonov=grad)
    201         try:
    202             self.listener(self.owner, it, data)
    203         except Exception:
    204             logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
    205             traceback.print_exc(file=sys.stdout)
    206             raise
    207 
    208 
    209 class HardavIterationListenerAdaptor(PISM.IP_SSAHardavTaoTikhonovProblemListener):
    210 
    211     """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
    212     on to a standard python-based listener.  Used internally by
    213     :class:`InvSSATaucSolver_Tikhonov`.  I.e. don't make one of these for yourself."""
    214 
    215     def __init__(self, owner, listener):
    216         """:param owner: The :class:`InvSSATaucSolver_Tikhonov` that constructed us
    217            :param listener: The python-based listener.
    218          """
    219         PISM.IP_SSAHardavTaoTikhonovProblemListener.__init__(self)
    220         self.owner = owner
    221         self.listener = listener
    222 
    223     def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
    224         """Called during IP_SSATaucTaoTikhonovProblem iterations.  Gathers together the long list of arguments
    225         into a dictionary and passes it along in a standard form to the python listener."""
    226         data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
    227                      zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
    228                      u=u, residual=diff_u, grad_JState=grad_u, grad_JTikhonov=grad)
    229         try:
    230             self.listener(self.owner, it, data)
    231         except Exception:
    232             logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
    233             traceback.print_exc(file=sys.stdout)
    234             raise