Source code for pyomo.repn.linear

#  ___________________________________________________________________________
#
#  Pyomo: Python Optimization Modeling Objects
#  Copyright (c) 2008-2024
#  National Technology and Engineering Solutions of Sandia, LLC
#  Under the terms of Contract DE-NA0003525 with National Technology and
#  Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
#  rights in this software.
#  This software is distributed under the 3-clause BSD License.
#  ___________________________________________________________________________

import logging
import sys
from operator import itemgetter
from itertools import filterfalse

from pyomo.common.deprecation import deprecation_warning
from pyomo.common.numeric_types import (
    native_types,
    native_numeric_types,
    native_complex_types,
)
from pyomo.core.expr.numeric_expr import (
    NegationExpression,
    ProductExpression,
    DivisionExpression,
    PowExpression,
    AbsExpression,
    UnaryFunctionExpression,
    Expr_ifExpression,
    MonomialTermExpression,
    LinearExpression,
    SumExpression,
    ExternalFunctionExpression,
    mutable_expression,
)
from pyomo.core.expr.relational_expr import (
    EqualityExpression,
    InequalityExpression,
    RangedExpression,
)
from pyomo.core.expr.visitor import StreamBasedExpressionVisitor, _EvaluationVisitor
from pyomo.core.expr import is_fixed, value
from pyomo.core.base.expression import Expression
import pyomo.core.kernel as kernel
from pyomo.repn.util import (
    BeforeChildDispatcher,
    ExitNodeDispatcher,
    ExprType,
    FileDeterminism,
    FileDeterminism_to_SortComponents,
    InvalidNumber,
    OrderedVarRecorder,
    VarRecorder,
    apply_node_operation,
    complex_number_error,
    initialize_exit_node_dispatcher,
    nan,
    sum_like_expression_types,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Short module-local aliases for the ExprType tags.  The handlers in this
# module label every intermediate result with one of these three tags and
# compare them by identity (``is``).
_CONSTANT = ExprType.CONSTANT
_LINEAR = ExprType.LINEAR
_GENERAL = ExprType.GENERAL


def _inv2str(val):
    return f"{val._str() if hasattr(val, '_str') else val}"


def _merge_dict(dest_dict, mult, src_dict):
    if mult == 1:
        for vid, coef in src_dict.items():
            if vid in dest_dict:
                dest_dict[vid] += coef
            else:
                dest_dict[vid] = coef
    else:
        for vid, coef in src_dict.items():
            if vid in dest_dict:
                dest_dict[vid] += mult * coef
            else:
                dest_dict[vid] = mult * coef


class LinearRepn(object):
    """Compact representation of ``multiplier * (constant + linear + nonlinear)``.

    Attributes
    ----------
    multiplier :
        Scalar factor applied to the whole representation.
    constant :
        Constant term.
    linear : dict
        Maps a variable id to its coefficient; the ids are resolved
        through ``visitor.var_map`` in :meth:`to_expression`.
    nonlinear :
        Remaining (nonlinear) expression, or ``None`` when absent.
    """

    __slots__ = ("multiplier", "constant", "linear", "nonlinear")

    def __init__(self):
        self.multiplier = 1
        self.constant = 0
        self.linear = {}
        self.nonlinear = None

    def __str__(self):
        head = f"LinearRepn(mult={self.multiplier}, const={self.constant}, "
        tail = f"linear={self.linear}, nonlinear={self.nonlinear})"
        return head + tail

    def __repr__(self):
        return str(self)

    def walker_exitNode(self):
        """Classify this result for the expression walker.

        Returns ``(_GENERAL, self)`` when a nonlinear term is present,
        ``(_LINEAR, self)`` when only linear terms are present, and
        otherwise ``(_CONSTANT, multiplier * constant)``.
        """
        if self.nonlinear is not None:
            return _GENERAL, self
        if self.linear:
            return _LINEAR, self
        return _CONSTANT, self.multiplier * self.constant

    def duplicate(self):
        """Return a copy of this object with an independent ``linear`` dict.

        ``__init__`` is bypassed; the nonlinear term is shared, not copied.
        """
        cls = self.__class__
        clone = cls.__new__(cls)
        clone.multiplier = self.multiplier
        clone.constant = self.constant
        clone.linear = dict(self.linear)
        clone.nonlinear = self.nonlinear
        return clone

    def to_expression(self, visitor):
        """Rebuild an expression from this representation.

        ``visitor.var_map`` is used to translate the keys of ``linear``
        back into variable objects.
        """
        # Seed the result with the nonlinear term itself (plain
        # assignment rather than "0 + term"): the term may be a
        # non-numeric node such as a relational expression.
        expr = self.nonlinear if self.nonlinear is not None else 0

        if self.linear:
            lookup = visitor.var_map
            with mutable_expression() as terms:
                for vid, coef in self.linear.items():
                    if coef:
                        terms += coef * lookup[vid]
            nterms = terms.nargs()
            if nterms > 1:
                expr += terms
            elif nterms == 1:
                # Avoid wrapping a single term in a sum node
                expr += terms.arg(0)

        if self.constant:
            expr += self.constant
        if self.multiplier != 1:
            expr *= self.multiplier
        return expr

    def append(self, other):
        """Fold a child result (from acceptChildResult) into this object.

        Notes
        -----
        The operator is assumed to be "+".  Providing ``append()`` lets
        a LinearRepn serve directly as the walker's `data` object, so
        the default acceptChildResult implementation (which calls
        `data.append()`) can be used without the cost of a custom
        callback.

        ``self.multiplier`` is always 1 here (append() is only called
        while processing a sum, so nothing can have changed it); the
        corresponding assertion is omitted for efficiency.
        """
        kind, child = other
        if kind is _CONSTANT:
            self.constant += child
            return

        scale = child.multiplier
        if not scale:
            # 0 * child contributes nothing to self.
            return

        if child.constant:
            self.constant += scale * child.constant
        if child.linear:
            _merge_dict(self.linear, scale, child.linear)
        if child.nonlinear is None:
            return

        term = scale * child.nonlinear if scale != 1 else child.nonlinear
        if self.nonlinear is None:
            self.nonlinear = term
        else:
            self.nonlinear += term
def to_expression(visitor, arg):
    """Convert a ``(type, value)`` walker result into an expression.

    Constant results are returned as their bare value; any other result
    is delegated to the value's own ``to_expression(visitor)`` method.
    """
    kind, payload = arg[0], arg[1]
    if kind is _CONSTANT:
        return payload
    return payload.to_expression(visitor)
# # NEGATION handlers # def _handle_negation_constant(visitor, node, arg): return (_CONSTANT, -1 * arg[1]) def _handle_negation_ANY(visitor, node, arg): arg[1].multiplier *= -1 return arg # # PRODUCT handlers # def _handle_product_constant_constant(visitor, node, arg1, arg2): ans = arg1[1] * arg2[1] if ans != ans: if not arg1[1] or not arg2[1]: a = _inv2str(arg1[1]) b = _inv2str(arg2[1]) deprecation_warning( f"Encountered {a}*{b} in expression tree. " "Mapping the NaN result to 0 for compatibility " "with the lp_v1 writer. In the future, this NaN " "will be preserved/emitted to comply with IEEE-754.", version='6.6.0', ) return _CONSTANT, 0 return _CONSTANT, ans def _handle_product_constant_ANY(visitor, node, arg1, arg2): arg2[1].multiplier *= arg1[1] return arg2 def _handle_product_ANY_constant(visitor, node, arg1, arg2): arg1[1].multiplier *= arg2[1] return arg1 def _handle_product_nonlinear(visitor, node, arg1, arg2): ans = visitor.Result() if not visitor.expand_nonlinear_products: ans.nonlinear = to_expression(visitor, arg1) * to_expression(visitor, arg2) return _GENERAL, ans # We are multiplying (A + Bx + C(x)) * (A + Bx + C(x)) _, x1 = arg1 _, x2 = arg2 ans.multiplier = x1.multiplier * x2.multiplier x1.multiplier = x2.multiplier = 1 # x1.const * x2.const [AA] ans.constant = x1.constant * x2.constant # x1.linear * x2.const [BA] + x1.const * x2.linear [AB] if x2.constant: c = x2.constant if c == 1: ans.linear = dict(x1.linear) else: ans.linear = {vid: c * coef for vid, coef in x1.linear.items()} if x1.constant: _merge_dict(ans.linear, x1.constant, x2.linear) ans.nonlinear = 0 if x1.constant and x2.nonlinear is not None: # [AC] ans.nonlinear += x1.constant * x2.nonlinear if x1.nonlinear is not None: # [CA] + [CB] + [CC] ans.nonlinear += x1.nonlinear * to_expression(visitor, arg2) if x1.linear: # [BB] + [BC] x1.constant = 0 x1.nonlinear = None x2.constant = 0 ans.nonlinear += to_expression(visitor, arg1) * to_expression(visitor, arg2) return _GENERAL, ans # # 
DIVISION handlers # def _handle_division_constant_constant(visitor, node, arg1, arg2): return _CONSTANT, apply_node_operation(node, (arg1[1], arg2[1])) def _handle_division_ANY_constant(visitor, node, arg1, arg2): arg1[1].multiplier = apply_node_operation(node, (arg1[1].multiplier, arg2[1])) return arg1 def _handle_division_nonlinear(visitor, node, arg1, arg2): ans = visitor.Result() ans.nonlinear = to_expression(visitor, arg1) / to_expression(visitor, arg2) return _GENERAL, ans # # EXPONENTIATION handlers # def _handle_pow_constant_constant(visitor, node, arg1, arg2): ans = apply_node_operation(node, (arg1[1], arg2[1])) if ans.__class__ in native_complex_types: ans = complex_number_error(ans, visitor, node) return _CONSTANT, ans def _handle_pow_ANY_constant(visitor, node, arg1, arg2): _, exp = arg2 if exp == 1: return arg1 elif exp > 1 and exp <= visitor.max_exponential_expansion and int(exp) == exp: _type, _arg = arg1 ans = _type, _arg.duplicate() for i in range(1, int(exp)): ans = visitor.exit_node_dispatcher[(ProductExpression, ans[0], _type)]( visitor, None, ans, (_type, _arg.duplicate()) ) return ans elif exp == 0: return _CONSTANT, 1 else: return _handle_pow_nonlinear(visitor, node, arg1, arg2) def _handle_pow_nonlinear(visitor, node, arg1, arg2): ans = visitor.Result() ans.nonlinear = to_expression(visitor, arg1) ** to_expression(visitor, arg2) return _GENERAL, ans # # ABS and UNARY handlers # def _handle_unary_constant(visitor, node, arg): ans = apply_node_operation(node, (arg[1],)) # Unary includes sqrt() which can return complex numbers if ans.__class__ in native_complex_types: ans = complex_number_error(ans, visitor, node) return _CONSTANT, ans def _handle_unary_nonlinear(visitor, node, arg): ans = visitor.Result() ans.nonlinear = node.create_node_with_local_data((to_expression(visitor, arg),)) return _GENERAL, ans # # NAMED EXPRESSION handlers # def _handle_named_constant(visitor, node, arg1): # Record this common expression 
visitor.subexpression_cache[id(node)] = arg1 return arg1 def _handle_named_ANY(visitor, node, arg1): # Record this common expression visitor.subexpression_cache[id(node)] = arg1 _type, arg1 = arg1 return _type, arg1.duplicate() # # EXPR_IF handlers # def _handle_expr_if_const(visitor, node, arg1, arg2, arg3): _type, _test = arg1 assert _type is _CONSTANT if _test: if _test != _test or _test.__class__ is InvalidNumber: # nan return _handle_expr_if_nonlinear(visitor, node, arg1, arg2, arg3) return arg2 else: return arg3 def _handle_expr_if_nonlinear(visitor, node, arg1, arg2, arg3): # Note: guaranteed that arg1 is not _CONSTANT ans = visitor.Result() ans.nonlinear = Expr_ifExpression( ( to_expression(visitor, arg1), to_expression(visitor, arg2), to_expression(visitor, arg3), ) ) return _GENERAL, ans # # Relational expression handlers # def _handle_equality_const(visitor, node, arg1, arg2): # It is exceptionally likely that if we get here, one of the # arguments is an InvalidNumber args, causes = InvalidNumber.parse_args(arg1[1], arg2[1]) try: ans = args[0] == args[1] except: ans = False causes.append(str(sys.exc_info()[1])) if causes: ans = InvalidNumber(ans, causes) return _CONSTANT, ans def _handle_equality_general(visitor, node, arg1, arg2): ans = visitor.Result() ans.nonlinear = EqualityExpression( (to_expression(visitor, arg1), to_expression(visitor, arg2)) ) return _GENERAL, ans def _handle_inequality_const(visitor, node, arg1, arg2): # It is exceptionally likely that if we get here, one of the # arguments is an InvalidNumber args, causes = InvalidNumber.parse_args(arg1[1], arg2[1]) try: ans = args[0] <= args[1] except: ans = False causes.append(str(sys.exc_info()[1])) if causes: ans = InvalidNumber(ans, causes) return _CONSTANT, ans def _handle_inequality_general(visitor, node, arg1, arg2): ans = visitor.Result() ans.nonlinear = InequalityExpression( (to_expression(visitor, arg1), to_expression(visitor, arg2)), node.strict ) return _GENERAL, ans def 
_handle_ranged_const(visitor, node, arg1, arg2, arg3): # It is exceptionally likely that if we get here, one of the # arguments is an InvalidNumber args, causes = InvalidNumber.parse_args(arg1[1], arg2[1], arg3[1]) try: ans = args[0] <= args[1] <= args[2] except: ans = False causes.append(str(sys.exc_info()[1])) if causes: ans = InvalidNumber(ans, causes) return _CONSTANT, ans def _handle_ranged_general(visitor, node, arg1, arg2, arg3): ans = visitor.Result() ans.nonlinear = RangedExpression( ( to_expression(visitor, arg1), to_expression(visitor, arg2), to_expression(visitor, arg3), ), node.strict, ) return _GENERAL, ans
def define_exit_node_handlers(_exit_node_handlers=None):
    """Populate (and return) the exit-node handler table.

    The table maps an expression class to a dict whose keys are tuples
    of child result types and whose values are the handler to call; the
    ``None`` key holds the fallback handler for that expression class.

    Parameters
    ----------
    _exit_node_handlers : dict, optional
        Existing table to update in place.  A new dict is created when
        omitted.
    """
    handlers = {} if _exit_node_handlers is None else _exit_node_handlers

    handlers[NegationExpression] = {
        None: _handle_negation_ANY,
        (_CONSTANT,): _handle_negation_constant,
    }

    # Monomial terms share the very same table object as products.
    product_table = {
        None: _handle_product_nonlinear,
        (_CONSTANT, _CONSTANT): _handle_product_constant_constant,
        (_CONSTANT, _LINEAR): _handle_product_constant_ANY,
        (_CONSTANT, _GENERAL): _handle_product_constant_ANY,
        (_LINEAR, _CONSTANT): _handle_product_ANY_constant,
        (_GENERAL, _CONSTANT): _handle_product_ANY_constant,
    }
    handlers[ProductExpression] = product_table
    handlers[MonomialTermExpression] = product_table

    handlers[DivisionExpression] = {
        None: _handle_division_nonlinear,
        (_CONSTANT, _CONSTANT): _handle_division_constant_constant,
        (_LINEAR, _CONSTANT): _handle_division_ANY_constant,
        (_GENERAL, _CONSTANT): _handle_division_ANY_constant,
    }

    handlers[PowExpression] = {
        None: _handle_pow_nonlinear,
        (_CONSTANT, _CONSTANT): _handle_pow_constant_constant,
        (_LINEAR, _CONSTANT): _handle_pow_ANY_constant,
        (_GENERAL, _CONSTANT): _handle_pow_ANY_constant,
    }

    # abs() shares the very same table object as the other unary functions.
    unary_table = {
        None: _handle_unary_nonlinear,
        (_CONSTANT,): _handle_unary_constant,
    }
    handlers[UnaryFunctionExpression] = unary_table
    handlers[AbsExpression] = unary_table

    handlers[Expression] = {
        None: _handle_named_ANY,
        (_CONSTANT,): _handle_named_constant,
    }

    # Expr_if: a constant test is resolved by _handle_expr_if_const
    # regardless of the types of the two branches.
    branch_types = (_CONSTANT, _LINEAR, _GENERAL)
    expr_if_table = {None: _handle_expr_if_nonlinear}
    expr_if_table.update(
        ((_CONSTANT, then_type, else_type), _handle_expr_if_const)
        for then_type in branch_types
        for else_type in branch_types
    )
    handlers[Expr_ifExpression] = expr_if_table

    handlers[EqualityExpression] = {
        None: _handle_equality_general,
        (_CONSTANT, _CONSTANT): _handle_equality_const,
    }
    handlers[InequalityExpression] = {
        None: _handle_inequality_general,
        (_CONSTANT, _CONSTANT): _handle_inequality_const,
    }
    handlers[RangedExpression] = {
        None: _handle_ranged_general,
        (_CONSTANT, _CONSTANT, _CONSTANT): _handle_ranged_const,
    }
    return handlers
class LinearBeforeChildDispatcher(BeforeChildDispatcher):
    """``beforeChild`` dispatch table for the linear walker.

    Each handler takes ``(visitor, child)`` and returns a 2-tuple:
    ``(True, None)`` to ask the walker to descend into ``child``, or
    ``(False, (type, value))`` with the already-computed result, where
    ``type`` is one of ``_CONSTANT`` / ``_LINEAR`` / ``_GENERAL``.
    """

    def __init__(self):
        # Special handling for external functions: will be handled
        # as terminal nodes from the point of view of the visitor
        self[ExternalFunctionExpression] = self._before_external
        # Special linear / summation expressions
        self[MonomialTermExpression] = self._before_monomial
        self[LinearExpression] = self._before_linear
        # NOTE(review): _before_general_expression is not defined in
        # this class; presumably inherited from BeforeChildDispatcher.
        self[SumExpression] = self._before_general_expression

    @staticmethod
    def _before_var(visitor, child):
        """Handle a variable: fixed (and not yet recorded) -> constant, else 1*var."""
        _id = id(child)
        if _id not in visitor.var_map:
            if child.fixed:
                return False, (_CONSTANT, visitor.check_constant(child.value, child))
            visitor.var_recorder.add(child)
        ans = visitor.Result()
        ans.linear[_id] = 1
        return False, (_LINEAR, ans)

    @staticmethod
    def _before_monomial(visitor, child):
        """Handle ``coef * var`` without descending into the node."""
        #
        # The following are performance optimizations for common
        # situations (Monomial terms and Linear expressions)
        #
        arg1, arg2 = child._args_
        if arg1.__class__ not in native_types:
            try:
                arg1 = visitor.check_constant(visitor.evaluate(arg1), arg1)
            except (ValueError, ArithmeticError):
                # Could not evaluate the coefficient: let the walker
                # process this node the general way.
                return True, None

        # We want to check / update the var_map before processing "0"
        # coefficients so that we are consistent with what gets added to the
        # var_map (e.g., 0*x*y: y is processed by _before_var and will
        # always be added, but x is processed here)
        _id = id(arg2)
        if _id not in visitor.var_map:
            if arg2.fixed:
                return False, (
                    _CONSTANT,
                    arg1 * visitor.check_constant(arg2.value, arg2),
                )
            visitor.var_recorder.add(arg2)

        # Trap multiplication by 0 and nan.
        if not arg1:
            if arg2.fixed:
                arg2 = visitor.check_constant(arg2.value, arg2)
                if arg2 != arg2:
                    deprecation_warning(
                        f"Encountered {arg1}*{_inv2str(arg2)} in expression "
                        "tree. Mapping the NaN result to 0 for compatibility "
                        "with the lp_v1 writer. In the future, this NaN "
                        "will be preserved/emitted to comply with IEEE-754.",
                        version='6.6.0',
                    )
            return False, (_CONSTANT, arg1)

        ans = visitor.Result()
        ans.linear[_id] = arg1
        return False, (_LINEAR, ans)

    @staticmethod
    def _before_linear(visitor, child):
        """Handle a LinearExpression by accumulating its arguments directly.

        Falls back to the general walker (``(True, None)``) as soon as
        any coefficient or argument cannot be evaluated.
        """
        var_map = visitor.var_map
        ans = visitor.Result()
        const = 0
        linear = ans.linear
        for arg in child.args:
            if arg.__class__ is MonomialTermExpression:
                arg1, arg2 = arg._args_
                if arg1.__class__ not in native_types:
                    try:
                        arg1 = visitor.check_constant(visitor.evaluate(arg1), arg1)
                    except (ValueError, ArithmeticError):
                        return True, None

                # Trap multiplication by 0 and nan.
                if not arg1:
                    if arg2.fixed:
                        arg2 = visitor.check_constant(arg2.value, arg2)
                        if arg2 != arg2:
                            deprecation_warning(
                                f"Encountered {arg1}*{_inv2str(arg2)} in expression "
                                "tree. Mapping the NaN result to 0 for compatibility "
                                "with the lp_v1 writer. In the future, this NaN "
                                "will be preserved/emitted to comply with IEEE-754.",
                                version='6.6.0',
                            )
                    continue

                _id = id(arg2)
                if _id not in var_map:
                    if arg2.fixed:
                        const += arg1 * visitor.check_constant(arg2.value, arg2)
                        continue
                    visitor.var_recorder.add(arg2)
                    linear[_id] = arg1
                elif _id in linear:
                    linear[_id] += arg1
                else:
                    linear[_id] = arg1
            elif arg.__class__ in native_numeric_types:
                const += arg
            elif arg.is_variable_type():
                _id = id(arg)
                if _id not in var_map:
                    if arg.fixed:
                        const += visitor.check_constant(arg.value, arg)
                        continue
                    visitor.var_recorder.add(arg)
                    linear[_id] = 1
                elif _id in linear:
                    linear[_id] += 1
                else:
                    linear[_id] = 1
            else:
                try:
                    const += visitor.check_constant(visitor.evaluate(arg), arg)
                except (ValueError, ArithmeticError):
                    return True, None
        if linear:
            ans.constant = const
            return False, (_LINEAR, ans)
        else:
            return False, (_CONSTANT, const)

    @staticmethod
    def _before_named_expression(visitor, child):
        """Reuse a cached named-expression result, or descend if not cached."""
        _id = id(child)
        if _id in visitor.subexpression_cache:
            _type, expr = visitor.subexpression_cache[_id]
            if _type is _CONSTANT:
                return False, (_type, expr)
            else:
                # Hand out a copy so the cached result is never modified
                return False, (_type, expr.duplicate())
        else:
            return True, None

    @staticmethod
    def _before_external(visitor, child):
        """Handle an external function call as a terminal node.

        When every argument is fixed the call is evaluated and returned
        as a constant; otherwise (or if evaluation fails) the node is
        kept as an opaque nonlinear term.
        """
        if all(is_fixed(arg) for arg in child.args):
            try:
                # Return the bare numeric value: every consumer of a
                # _CONSTANT result in this module (LinearRepn.append, the
                # product / division handlers, finalizeResult) expects a
                # number here, not a Result object.
                return False, (
                    _CONSTANT,
                    visitor.check_constant(visitor.evaluate(child), child),
                )
            except Exception:
                # Best effort only: if the function cannot be evaluated,
                # fall through and treat it as a nonlinear term.
                pass
        ans = visitor.Result()
        ans.nonlinear = child
        return False, (_GENERAL, ans)
class LinearRepnVisitor(StreamBasedExpressionVisitor):
    """Expression walker that compiles an expression into a ``LinearRepn``.

    Class attributes
    ----------------
    Result :
        Class used to hold intermediate and final results.
    before_child_dispatcher / exit_node_dispatcher :
        Dispatch tables used by :meth:`beforeChild` and :meth:`exitNode`.
    expand_nonlinear_products : bool
        Read by the product handler to decide whether products of
        non-constant terms are expanded.
    max_exponential_expansion : int
        Largest integer exponent that the power handler expands into
        repeated products.
    """

    Result = LinearRepn
    before_child_dispatcher = LinearBeforeChildDispatcher()
    exit_node_dispatcher = ExitNodeDispatcher(
        initialize_exit_node_dispatcher(define_exit_node_handlers())
    )
    expand_nonlinear_products = False
    max_exponential_expansion = 1

    def __init__(
        self,
        subexpression_cache,
        var_map=None,
        var_order=None,
        sorter=None,
        var_recorder=None,
    ):
        """
        Parameters
        ----------
        subexpression_cache : dict
            Cache of named-expression results, keyed by ``id(node)``.
        var_map, var_order, sorter :
            Deprecated; used to build an ``OrderedVarRecorder``.  May
            not be combined with ``var_recorder``.
        var_recorder :
            Object that records the variables encountered; a default
            ``VarRecorder`` is created when omitted.

        Raises
        ------
        ValueError
            If ``var_recorder`` is given together with any of the
            deprecated arguments.
        """
        super().__init__()
        self.subexpression_cache = subexpression_cache
        if any(_ is not None for _ in (var_map, var_order, sorter)):
            if var_recorder is not None:
                raise ValueError(
                    "LinearRepnVisitor: cannot specify any of var_map, "
                    "var_order, or sorter with var_recorder"
                )
            deprecation_warning(
                "var_map, var_order, and sorter are deprecated arguments to "
                "LinearRepnVisitor(). Please pass the VarRecorder object directly.",
                version='6.8.1',
            )
            var_recorder = OrderedVarRecorder(var_map, var_order, sorter)
        if var_recorder is None:
            var_recorder = VarRecorder(
                {}, FileDeterminism_to_SortComponents(FileDeterminism.ORDERED)
            )
        self.var_recorder = var_recorder
        self.var_map = var_recorder.var_map
        self._eval_expr_visitor = _EvaluationVisitor(True)
        self.evaluate = self._eval_expr_visitor.dfs_postorder_stack

    def check_constant(self, ans, obj):
        """Validate an evaluated constant.

        Returns ``ans`` unchanged when it is a native numeric, non-NaN
        value.  ``None``, NaN, and values that cannot be converted with
        ``float()`` are returned wrapped in an ``InvalidNumber``; complex
        values are passed to ``complex_number_error``.  ``obj`` is only
        used to build the diagnostic message.
        """
        if ans.__class__ not in native_numeric_types:
            # None can be returned from uninitialized Var/Param objects
            if ans is None:
                return InvalidNumber(
                    None, f"'{obj}' evaluated to a nonnumeric value '{ans}'"
                )
            if ans.__class__ is InvalidNumber:
                return ans
            elif ans.__class__ in native_complex_types:
                return complex_number_error(ans, self, obj)
            else:
                # It is possible to get other non-numeric types.  Most
                # common are bool and 1-element numpy.array().  We will
                # attempt to convert the value to a float before
                # proceeding.
                #
                # TODO: we should check bool and warn/error (while bool is
                # convertible to float in Python, they have very
                # different semantic meanings in Pyomo).
                try:
                    ans = float(ans)
                except Exception:
                    # Any ordinary conversion failure marks the value as
                    # invalid; KeyboardInterrupt / SystemExit propagate.
                    return InvalidNumber(
                        ans, f"'{obj}' evaluated to a nonnumeric value '{ans}'"
                    )
        if ans != ans:
            return InvalidNumber(
                nan, f"'{obj}' evaluated to a nonnumeric value '{ans}'"
            )
        return ans

    def initializeWalker(self, expr):
        """Short-circuit the walk when the root itself can be resolved directly."""
        walk, result = self.beforeChild(None, expr, 0)
        if not walk:
            return False, self.finalizeResult(result)
        return True, expr

    def beforeChild(self, node, child, child_idx):
        """Dispatch on the child's class."""
        return self.before_child_dispatcher[child.__class__](self, child)

    def enterNode(self, node):
        """Return the node's arguments and the container for child results."""
        # SumExpression are potentially large nary operators.  Directly
        # populate the result
        if node.__class__ in sum_like_expression_types:
            return node.args, self.Result()
        else:
            return node.args, []

    def exitNode(self, node, data):
        """Combine the child results collected in ``data`` for ``node``."""
        if data.__class__ is self.Result:
            # Sum-like node: the children were already accumulated
            return data.walker_exitNode()
        #
        # General expressions...
        #
        return self.exit_node_dispatcher[(node.__class__, *map(itemgetter(0), data))](
            self, node, *data
        )

    def finalizeResult(self, result):
        """Normalize the final ``(type, value)`` pair into a ``Result`` object.

        The multiplier is folded into the constant, linear coefficients
        and nonlinear term, and zero linear coefficients are removed.
        """
        ans = result[1]
        if ans.__class__ is self.Result:
            mult = ans.multiplier
            if mult == 1:
                # mult is identity: only thing to do is filter out zero
                # coefficients
                linear = ans.linear
                zeros = [vid for vid, coef in linear.items() if not coef]
                for vid in zeros:
                    del linear[vid]
            elif not mult:
                # the multiplier has cleared out the entire expression.
                # Warn if this is suppressing a NaN (unusual, and
                # non-standard, but we will wait to remove this behavior
                # for the time being)
                if ans.constant != ans.constant or any(
                    c != c for c in ans.linear.values()
                ):
                    deprecation_warning(
                        f"Encountered {mult}*nan in expression tree. "
                        "Mapping the NaN result to 0 for compatibility "
                        "with the lp_v1 writer. In the future, this NaN "
                        "will be preserved/emitted to comply with IEEE-754.",
                        version='6.6.0',
                    )
                return self.Result()
            else:
                # mult not in {0, 1}: factor it into the constant,
                # linear coefficients, and nonlinear term
                linear = ans.linear
                zeros = []
                for vid, coef in linear.items():
                    if coef:
                        linear[vid] = coef * mult
                    else:
                        zeros.append(vid)
                for vid in zeros:
                    del linear[vid]
                if ans.nonlinear is not None:
                    ans.nonlinear *= mult
                if ans.constant:
                    ans.constant *= mult
                ans.multiplier = 1
            return ans

        # A bare constant: wrap it in a fresh Result
        ans = self.Result()
        assert result[0] is _CONSTANT
        ans.constant = result[1]
        return ans