Source code for pyomo.core.expr.compare

#  ___________________________________________________________________________
#
#  Pyomo: Python Optimization Modeling Objects
#  Copyright (c) 2008-2025
#  National Technology and Engineering Solutions of Sandia, LLC
#  Under the terms of Contract DE-NA0003525 with National Technology and
#  Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
#  rights in this software.
#  This software is distributed under the 3-clause BSD License.
#  ___________________________________________________________________________
import collections
from .visitor import StreamBasedExpressionVisitor
from .numvalue import nonpyomo_leaf_types
from pyomo.core.expr import (
    LinearExpression,
    MonomialTermExpression,
    SumExpression,
    ExpressionBase,
    ProductExpression,
    DivisionExpression,
    PowExpression,
    NegationExpression,
    UnaryFunctionExpression,
    ExternalFunctionExpression,
    NPV_ProductExpression,
    NPV_DivisionExpression,
    NPV_PowExpression,
    NPV_SumExpression,
    NPV_NegationExpression,
    NPV_UnaryFunctionExpression,
    NPV_ExternalFunctionExpression,
    Expr_ifExpression,
    AbsExpression,
    NPV_AbsExpression,
    NumericValue,
    RangedExpression,
    InequalityExpression,
    EqualityExpression,
    GetItemExpression,
)
from typing import List
from pyomo.common.collections import Sequence
from pyomo.common.errors import PyomoException
from pyomo.common.formatting import tostr
from pyomo.common.numeric_types import native_types


[docs] def handle_expression(node: ExpressionBase, pn: List): pn.append((type(node), node.nargs())) return node.args
[docs] def handle_named_expression(node, pn: List, include_named_exprs=True): if include_named_exprs: pn.append((type(node), 1)) return (node.expr,)
[docs] def handle_unary_expression(node: UnaryFunctionExpression, pn: List): pn.append((type(node), 1, node.getname())) return node.args
[docs] def handle_external_function_expression(node: ExternalFunctionExpression, pn: List): pn.append((type(node), node.nargs(), node._fcn)) return node.args
[docs] def handle_sequence(node: collections.abc.Sequence, pn: List): pn.append((collections.abc.Sequence, len(node))) return list(node)
[docs] def handle_inequality(node: collections.abc.Sequence, pn: List): pn.append((type(node), node.nargs(), node.strict)) return node.args
def _generic_expression_handler(): return handle_expression handler = collections.defaultdict(_generic_expression_handler) handler[UnaryFunctionExpression] = handle_unary_expression handler[NPV_UnaryFunctionExpression] = handle_unary_expression handler[ExternalFunctionExpression] = handle_external_function_expression handler[NPV_ExternalFunctionExpression] = handle_external_function_expression handler[AbsExpression] = handle_unary_expression handler[NPV_AbsExpression] = handle_unary_expression handler[InequalityExpression] = handle_inequality handler[RangedExpression] = handle_inequality handler[list] = handle_sequence
[docs] class PrefixVisitor(StreamBasedExpressionVisitor):
[docs] def __init__(self, include_named_exprs=True): super().__init__() self._result = None self._include_named_exprs = include_named_exprs
def initializeWalker(self, expr): self._result = [] return True, None def enterNode(self, node): ntype = type(node) if ntype in nonpyomo_leaf_types: self._result.append(node) return tuple(), None if ntype in handler: return handler[ntype](node, self._result), None if hasattr(node, 'is_expression_type'): if node.is_expression_type(): if node.is_named_expression_type(): return ( handle_named_expression( node, self._result, self._include_named_exprs ), None, ) else: return handler[ntype](node, self._result), None elif hasattr(node, '__len__'): handler[ntype] = handle_sequence return handle_sequence(node, self._result), None self._result.append(node) return tuple(), None def finalizeResult(self, result): ans = self._result self._result = None return ans
[docs] def convert_expression_to_prefix_notation(expr, include_named_exprs=True): """ This function converts pyomo expressions to a list that looks very much like prefix notation. The result can be used in equality comparisons to compare expression trees. Note that the data structure returned by this function might be changed in the future. However, we will maintain that the result can be used in equality comparisons. Also note that the result should really only be used in equality comparisons if the equality comparison is expected to return True. If the expressions being compared are expected to be different, then the equality comparison will often result in an error rather than returning False. m = ConcreteModel() m.x = Var() m.y = Var() e1 = m.x * m.y e2 = m.x * m.y e3 = m.x + m.y convert_expression_to_prefix_notation(e1) == convert_expression_to_prefix_notation(e2) # True convert_expression_to_prefix_notation(e1) == convert_expression_to_prefix_notation(e3) # Error However, the compare_expressions function can be used: compare_expressions(e1, e2) # True compare_expressions(e1, e3) # False Parameters ---------- expr: NumericValue A Pyomo expression, Var, or Param Returns ------- prefix_notation: list The expression in prefix notation """ visitor = PrefixVisitor(include_named_exprs=include_named_exprs) return visitor.walk_expression(expr)
[docs] def compare_expressions(expr1, expr2, include_named_exprs=True): """Returns True if 2 expression trees are identical, False otherwise. Parameters ---------- expr1: NumericValue A Pyomo Var, Param, or expression expr2: NumericValue A Pyomo Var, Param, or expression include_named_exprs: bool If False, then named expressions will be ignored. In other words, this function will return True if one expression has a named expression and the other does not as long as the rest of the expression trees are identical. Returns ------- res: bool A bool indicating whether or not the expressions are identical. """ pn1 = convert_expression_to_prefix_notation( expr1, include_named_exprs=include_named_exprs ) pn2 = convert_expression_to_prefix_notation( expr2, include_named_exprs=include_named_exprs ) try: res = pn1 == pn2 except (PyomoException, AttributeError): res = False return res
[docs] def assertExpressionsEqual(test, a, b, include_named_exprs=True, places=None): """unittest-based assertion for comparing expressions This converts the expressions `a` and `b` into prefix notation and then compares the resulting lists. Parameters ---------- test: unittest.TestCase The unittest `TestCase` class that is performing the test. a: ExpressionBase or native type b: ExpressionBase or native type include_named_exprs : bool If True (the default), the comparison expands all named expressions when generating the prefix notation places : int Number of decimal places required for equality of floating point numbers in the expression. If None (the default), the expressions must be exactly equal. """ prefix_a = convert_expression_to_prefix_notation(a, include_named_exprs) prefix_b = convert_expression_to_prefix_notation(b, include_named_exprs) try: test.assertEqual(len(prefix_a), len(prefix_b)) for _a, _b in zip(prefix_a, prefix_b): test.assertIs(_a.__class__, _b.__class__) # If _a is nan, check _b is nan if _a != _a: test.assertTrue(_b != _b) else: if places is None: test.assertEqual(_a, _b) else: test.assertAlmostEqual(_a, _b, places=places) except (PyomoException, AssertionError): test.fail( f"Expressions not equal:\n\t" f"{tostr(prefix_a)}\n\t!=\n\t{tostr(prefix_b)}" )
[docs] def assertExpressionsStructurallyEqual( test, a, b, include_named_exprs=True, places=None ): """unittest-based assertion for comparing expressions This converts the expressions `a` and `b` into prefix notation and then compares the resulting lists. Operators and (non-native type) leaf nodes in the prefix representation are converted to strings before comparing (so that things like variables can be compared across clones or pickles) Parameters ---------- test: unittest.TestCase The unittest `TestCase` class that is performing the test. a: ExpressionBase or native type b: ExpressionBase or native type include_named_exprs : bool If True (the default), the comparison expands all named expressions when generating the prefix notation places : int Number of decimal places required for equality of floating point numbers in the expression. If None (the default), the expressions must be exactly equal. """ prefix_a = convert_expression_to_prefix_notation(a, include_named_exprs) prefix_b = convert_expression_to_prefix_notation(b, include_named_exprs) # Convert leaf nodes and operators to their string equivalents for prefix in (prefix_a, prefix_b): for i, v in enumerate(prefix): if type(v) in native_types: continue if type(v) is tuple: # This is an expression node. Most expression nodes are # 2-tuples (node type, nargs), but some are 3-tuples # with supplemental data. The biggest problem is # external functions, where the third element is the # external function. We need to convert that to a # string to support "structural" comparisons. if len(v) == 3: prefix[i] = v[:2] + (str(v[2]),) continue # This should be a leaf node (Var, mutable Param, etc.). # Convert to string to support "structural" comparison # (e.g., across clones) prefix[i] = str(v) try: test.assertEqual(len(prefix_a), len(prefix_b)) for _a, _b in zip(prefix_a, prefix_b): if _a.__class__ not in native_types and _b.__class__ not in native_types: test.assertIs(_a.__class__, _b.__class__) if _a != _a: test.assertTrue(_b != _b) else: if places is None: test.assertEqual(_a, _b) else: test.assertAlmostEqual(_a, _b, places=places) except (PyomoException, AssertionError): test.fail( f"Expressions not structurally equal:\n\t" f"{tostr(prefix_a)}\n\t!=\n\t{tostr(prefix_b)}" )