ssa_tao.py (10577B)
# Copyright (C) 2012, 2014, 2015, 2016, 2018 David Maxwell and Constantine Khroulev
#
# This file is part of PISM.
#
# PISM is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation; either version 3 of the License, or (at your option) any later
# version.
#
# PISM is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License
# along with PISM; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

"""Inverse SSA solvers using the TAO library."""

import PISM
from PISM.util import Bunch
from PISM.logging import logError
from PISM.invert.ssa import InvSSASolver

import sys
import traceback


class InvSSASolver_Tikhonov(InvSSASolver):

    """Inverse SSA solver based on Tikhonov iteration using TAO."""

    # Dictionary converting PISM algorithm names to the corresponding
    # TAO algorithms used to implement the Tikhonov minimization.
    # TAO type names carry a "tao_" prefix when PETSc is older than 3.5.0.
    # The version query is skipped when PISM.imported_from_sphinx is set
    # (presumably so documentation builds do not need a working PETSc).
    if (not PISM.imported_from_sphinx) and PISM.PETSc.Sys.getVersion() < (3, 5, 0):
        tao_types = {'tikhonov_lmvm': 'tao_lmvm',
                     'tikhonov_cg': 'tao_cg',
                     'tikhonov_lcl': 'tao_lcl',
                     'tikhonov_blmvm': 'tao_blmvm'}
    else:
        tao_types = {'tikhonov_lmvm': 'lmvm',
                     'tikhonov_cg': 'cg',
                     'tikhonov_lcl': 'lcl',
                     'tikhonov_blmvm': 'blmvm'}

    def __init__(self, ssarun, method):
        """
        :param ssarun: The :class:`PISM.invert.ssa.SSAForwardRun` defining the forward problem.
        :param method: String describing the actual algorithm to use. Must be a key in :attr:`tao_types`.
        :raises ValueError: if ``method`` is not a key in :attr:`tao_types`.
        """
        InvSSASolver.__init__(self, ssarun, method)
        self.listeners = []
        self.solver = None
        self.ip = None
        if method not in self.tao_types:
            raise ValueError("Unknown TAO Tikhonov inversion method: %s" % method)

    def addIterationListener(self, listener):
        """Add a listener to be called after each iteration. See :ref:`Listeners`."""
        self.listeners.append(listener)

    def addDesignUpdateListener(self, listener):
        """Add a listener to be called after each time the design variable is changed.

        NOTE: this solver keeps a single listener list, so design-update listeners are
        registered exactly like iteration listeners and are invoked through the same
        per-iteration adaptor.
        """
        self.listeners.append(listener)

    def solveForward(self, zeta, out=None):
        r"""Given a parameterized design variable value :math:`\zeta`, solve the SSA.
        See :cpp:class:`IP_TaucParam` for a discussion of parameterizations.

        :param zeta: :cpp:class:`IceModelVec` containing :math:`\zeta`.
        :param out: optional :cpp:class:`IceModelVec` for storage of the computation result.
        :returns: An :cpp:class:`IceModelVec` containing the computation result. When ``out``
                  is not given, this is the forward model's own solution vector, not a copy.
        :raises PISM.AlgorithmFailureException: if linearizing the SSA at ``zeta`` fails.
        """
        ssa = self.ssarun.ssa

        reason = ssa.linearize_at(zeta)
        if reason.failed():
            raise PISM.AlgorithmFailureException(reason)
        if out is not None:
            out.copy_from(ssa.solution())
        else:
            out = ssa.solution()
        return out

    def solveInverse(self, zeta0, u_obs, zeta_inv):
        r"""Executes the inversion algorithm.

        :param zeta0: The best `a-priori` guess for the value of the parameterized design variable :math:`\zeta`.
        :param u_obs: :cpp:class:`IceModelVec2V` of observed surface velocities.
        :param zeta_inv: :cpp:class:`IceModelVec` containing the starting value of :math:`\zeta`
                         for minimization of the Tikhonov functional.
        :returns: A :cpp:class:`TerminationReason`.
        :raises RuntimeError: if the design variable is neither ``'tauc'`` nor ``'hardav'``,
                              or if ``'tikhonov_lcl'`` is requested for ``'hardav'``.
        """
        eta = self.config.get_number("inverse.tikhonov.penalty_weight")

        design_var = self.ssarun.designVariable()
        if design_var == 'tauc':
            if self.method == 'tikhonov_lcl':
                problemClass = PISM.IP_SSATaucTaoTikhonovProblemLCL
                solverClass = PISM.IP_SSATaucTaoTikhonovProblemLCLSolver
                listenerClass = TaucLCLIterationListenerAdaptor
            else:
                problemClass = PISM.IP_SSATaucTaoTikhonovProblem
                solverClass = PISM.IP_SSATaucTaoTikhonovSolver
                listenerClass = TaucIterationListenerAdaptor
        elif design_var == 'hardav':
            if self.method == 'tikhonov_lcl':
                # This module defines no LCL listener adaptor for the 'hardav'
                # design variable, so this combination cannot be run.
                raise RuntimeError("The 'tikhonov_lcl' inversion method is not supported "
                                   "for design variable 'hardav'")
            problemClass = PISM.IP_SSAHardavTaoTikhonovProblem
            solverClass = PISM.IP_SSAHardavTaoTikhonovSolver
            listenerClass = HardavIterationListenerAdaptor
        else:
            raise RuntimeError("Unsupported design variable '%s' for InvSSASolver_Tikhonov. "
                               "Expected 'tauc' or 'hardav'" % design_var)

        tao_type = self.tao_types[self.method]
        (stateFunctional, designFunctional) = PISM.invert.ssa.createTikhonovFunctionals(self.ssarun)

        self.ip = problemClass(self.ssarun.ssa, zeta0, u_obs, eta, stateFunctional, designFunctional)
        self.solver = solverClass(self.ssarun.grid.com, tao_type, self.ip)

        max_it = int(self.config.get_number("inverse.max_iterations"))
        self.solver.setMaximumIterations(max_it)

        # Wrap each python listener in an adaptor the C++ problem can call, and
        # register it. The local list keeps the adaptors referenced while solve() runs.
        adaptors = [listenerClass(self, listener) for listener in self.listeners]
        for adaptor in adaptors:
            self.ip.addListener(adaptor)

        self.ip.setInitialGuess(zeta_inv)

        vecs = self.ssarun.modeldata.vecs
        if vecs.has('zeta_fixed_mask'):
            # NOTE(review): the tauc-specific setter is used for every design
            # variable, including 'hardav' -- confirm this is intended.
            self.ssarun.ssa.set_tauc_fixed_locations(vecs.zeta_fixed_mask)

        return self.solver.solve()

    def inverseSolution(self):
        """Returns a tuple ``(zeta,u)`` of :cpp:class:`IceModelVec`'s corresponding to the values
        of the design and state variables at the end of inversion."""
        zeta = self.ip.designSolution()
        u = self.ip.stateSolution()
        return (zeta, u)


def _call_listener(owner, listener, it, data):
    """Forward one iteration's data to a python-based listener.

    Calls ``listener(owner, it, data)``. If the listener raises, an error is logged,
    the traceback is printed to stdout, and the exception is re-raised.
    """
    try:
        listener(owner, it, data)
    except Exception:
        logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
        traceback.print_exc(file=sys.stdout)
        raise


class TaucLCLIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemLCLListener):

    """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemLCLListener`
    on to a standard python-based listener. Used internally by
    :class:`InvSSASolver_Tikhonov`. I.e. don't make one of these for yourself."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that constructed us
        :param listener: The python-based listener.
        """
        PISM.IP_SSATaucTaoTikhonovProblemLCLListener.__init__(self)
        self.owner = owner
        self.listener = listener

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, constraints):
        """Called during IP_SSATaucTaoTikhonovProblemLCL iterations. Gathers together the long list of arguments
        into a dictionary and passes it along in standard form to the python listener."""
        data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
                     zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
                     u=u, residual=diff_u, grad_JState=grad_u,
                     constraints=constraints)
        _call_listener(self.owner, self.listener, it, data)


class TaucIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemListener):

    """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
    on to a standard python-based listener. Used internally by
    :class:`InvSSASolver_Tikhonov`. I.e. don't make one of these for yourself."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that constructed us
        :param listener: The python-based listener.
        """
        PISM.IP_SSATaucTaoTikhonovProblemListener.__init__(self)
        self.owner = owner
        self.listener = listener

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
        """Called during IP_SSATaucTaoTikhonovProblem iterations. Gathers together the long list of arguments
        into a dictionary and passes it along in a standard form to the python listener."""
        data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
                     zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
                     u=u, residual=diff_u, grad_JState=grad_u, grad_JTikhonov=grad)
        _call_listener(self.owner, self.listener, it, data)


class HardavIterationListenerAdaptor(PISM.IP_SSAHardavTaoTikhonovProblemListener):

    """Adaptor converting calls to a C++ :cpp:class:`IP_SSAHardavTaoTikhonovProblemListener`
    on to a standard python-based listener. Used internally by
    :class:`InvSSASolver_Tikhonov`. I.e. don't make one of these for yourself."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that constructed us
        :param listener: The python-based listener.
        """
        PISM.IP_SSAHardavTaoTikhonovProblemListener.__init__(self)
        self.owner = owner
        self.listener = listener

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
        """Called during IP_SSAHardavTaoTikhonovProblem iterations. Gathers together the long list of arguments
        into a dictionary and passes it along in a standard form to the python listener."""
        data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
                     zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
                     u=u, residual=diff_u, grad_JState=grad_u, grad_JTikhonov=grad)
        _call_listener(self.owner, self.listener, it, data)