# Source code for pyomo.contrib.cp.transform.logical_to_disjunctive_walker

#  ___________________________________________________________________________
#
#  Pyomo: Python Optimization Modeling Objects
#  Copyright (c) 2008-2024
#  National Technology and Engineering Solutions of Sandia, LLC
#  Under the terms of Contract DE-NA0003525 with National Technology and
#  Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
#  rights in this software.
#  This software is distributed under the 3-clause BSD License.
#  ___________________________________________________________________________

from pyomo.common.collections import ComponentMap
from pyomo.common.errors import MouseTrap
from pyomo.core.expr.expr_common import ExpressionType
from pyomo.core.expr.visitor import StreamBasedExpressionVisitor
import pyomo.core.expr as EXPR
from pyomo.core.base import (
    Binary,
    Constraint,
    ConstraintList,
    NonNegativeIntegers,
    VarList,
    value,
)
import pyomo.core.base.boolean_var as BV
from pyomo.core.base.expression import ScalarExpression, ExpressionData
from pyomo.core.base.param import ScalarParam, ParamData
from pyomo.core.base.var import ScalarVar, VarData
from pyomo.gdp.disjunct import AutoLinkedBooleanVar, Disjunct, Disjunction


def _dispatch_boolean_var(visitor, node):
    """Replace a Boolean variable leaf with a binary variable.

    The binary is looked up in ``visitor.boolean_to_binary_map``.  On a miss,
    the binary already associated with ``node`` is recorded if there is one;
    otherwise a new entry is added to ``visitor.z_vars``, recorded, and
    associated with ``node``.  When ``node`` is fixed, the binary is fixed as
    well and given the integer form of the node's value (or ``None`` when the
    node has no value).

    Returns:
        tuple: ``(False, binary)`` -- the walker does not descend into ``node``.
    """
    binary_map = visitor.boolean_to_binary_map
    if node not in binary_map:
        binary = node.get_associated_binary()
        if binary is None:
            # Nothing is linked to this Boolean yet: create a fresh binary,
            # record it, and link the two.
            binary = visitor.z_vars.add()
            binary_map[node] = binary
            node.associate_binary_var(binary)
        else:
            binary_map[node] = binary
    binary = binary_map[node]
    if node.fixed:
        # Mirror the fixed state (and value, if any) onto the binary.
        binary.fixed = True
        fixed_value = None if node.value is None else int(node.value)
        binary.set_value(fixed_value, skip_validation=True)
    return False, binary


def _dispatch_var(visitor, node):
    """Pass a numeric variable leaf through unchanged, without descending."""
    descend = False
    return descend, node


def _dispatch_param(visitor, node):
    """Pass a parameter leaf through unchanged, without descending."""
    descend = False
    return descend, node


def _dispatch_expression(visitor, node):
    """Replace a named expression leaf with its ``expr``, without descending."""
    wrapped = node.expr
    return False, wrapped


def _before_relational_expr(visitor, node):
    """Reject a relational expression used as a Boolean term.

    Raises:
        MouseTrap: always; this case is not yet supported.
    """
    message = (
        "The RelationalExpression '%s' was used as a Boolean term "
        "in a logical proposition. This is not yet supported "
        "when transforming to disjunctive form." % node
    )
    raise MouseTrap(message)


def _dispatch_not(visitor, node, a):
    """Return a variable ``z`` constrained by ``z == 1 - a`` (i.e. ``z == !a``).

    Results are cached in ``visitor.expansions`` keyed by ``a``, so a repeated
    negation of the same operand reuses the existing variable and adds no new
    constraint.
    """
    cache = visitor.expansions
    if a in cache:
        return cache[a]
    negated = visitor.z_vars.add()
    visitor.constraints.add(negated == 1 - a)
    cache[a] = negated
    return negated


def _dispatch_implication(visitor, node, a, b):
    """Encode ``a -> b`` as the disjunction ``!a v b`` and return its result."""
    not_a = 1 - a
    return _dispatch_or(visitor, node, not_a, b)


def _dispatch_equivalence(visitor, node, a, b):
    """Encode ``a <-> b`` as ``(!a v b) ^ (a v !b)`` and return its result."""
    # The two implications are built in this order before being conjoined.
    a_implies_b = _dispatch_or(visitor, node, 1 - a, b)
    b_implies_a = _dispatch_or(visitor, node, a, 1 - b)
    return _dispatch_and(visitor, node, a_implies_b, b_implies_a)


def _dispatch_and(visitor, node, *args):
    """Return a new variable ``z`` constrained so that ``z == a ^ b ^ ...``.

    One constraint ``arg >= z`` is added per operand, followed by
    ``len(args) - sum(args) >= 1 - z``.
    """
    z = visitor.z_vars.add()
    add_constraint = visitor.constraints.add
    for term in args:
        add_constraint(term >= z)
    add_constraint(len(args) - sum(args) >= 1 - z)
    return z


def _dispatch_or(visitor, node, *args):
    """Return a new variable ``z`` constrained so that ``z == a v b v ...``.

    The encoding is ``(!z v a v b v ...) ^ (z v !a) ^ (z v !b) ^ ...``: first
    the single constraint over all operands, then one constraint per operand.
    """
    z = visitor.z_vars.add()
    add_constraint = visitor.constraints.add
    add_constraint((1 - z) + sum(args) >= 1)
    for term in args:
        add_constraint(z + (1 - term) >= 1)
    return z


def _dispatch_xor(visitor, node, a, b):
    """Encode ``a XOR b`` as the special case "exactly one of ``a``, ``b``"."""
    required_true = 1
    return _dispatch_exactly(visitor, node, required_true, a, b)


def _get_integer_value(n, node):
    """Validate the count argument ``n`` of the cardinality expression ``node``.

    Returns:
        ``n`` itself, when it is an integer-valued native numeric or a
        non-native object that is not potentially variable.

    Raises:
        ValueError: if ``n`` is any other native value.
        MouseTrap: if ``n`` is potentially variable.
    """
    n_type = n.__class__
    if n_type in EXPR.native_numeric_types and int(n) == n:
        return n
    if n_type in EXPR.native_types:
        raise ValueError(
            "The first argument to '%s' must be an integer.\n\tRecieved: %s" % (node, n)
        )
    if not n.is_potentially_variable():
        return n
    # [ESJ 11/22]: This is probably worth supporting sometime, but right now
    # we are abiding by what docplex allows in their 'count' function. Part of
    # supporting this will be making sure we catch strict inequalities in the
    # GDP transformations. Because if we don't know that n is integer-valued
    # we will be forced to write strict inequalities instead of incrementing
    # or decrementing by 1 in the disjunctions.
    raise MouseTrap(
        "The first argument '%s' to '%s' is potentially variable. "
        "This may be a mathematically coherent expression; However "
        "it is not yet supported to convert it to a disjunctive "
        "program." % (n, node)
    )


def _dispatch_exactly(visitor, node, *args):
    """Encode ``exactly(n, ...)``: ``z`` is true iff ``sum(args[1:]) == args[0]``.

    This is currently implemented as the nested disjunction::

        [sum(args[1:]) == n] v [[sum(args[1:]) <= n - 1] v [sum(args[1:]) >= n + 1]]

    Args:
        visitor: the walker holding the ``disjuncts`` and ``disjunctions``
            containers that receive the new components.
        node: the expression being transformed (used for error messages).
        *args: ``args[0]`` is the required count ``n``; the remaining entries
            are the already-transformed operands.

    Returns:
        The binary associated with the indicator variable of the equality
        disjunct.

    Raises:
        Whatever ``_get_integer_value`` raises for an unsupported ``n``.
    """
    n = _get_integer_value(args[0], node)
    sum_expr = sum(args[1:])
    # Each disjunct is obtained by indexing ``visitor.disjuncts`` at its
    # current length, so the two lookups below must stay in this order.
    # NOTE(review): presumably the lookup creates the Disjunct on first
    # access -- confirm against the container's construction.
    equality_disj = visitor.disjuncts[len(visitor.disjuncts)]
    equality_disj.constraint = Constraint(expr=sum_expr == n)
    inequality_disj = visitor.disjuncts[len(visitor.disjuncts)]
    # n was validated by _get_integer_value, so "!= n" is written as
    # "<= n - 1 or >= n + 1" rather than with strict inequalities.
    inequality_disj.disjunction = Disjunction(
        expr=[[sum_expr <= n - 1], [sum_expr >= n + 1]]
    )
    visitor.disjunctions[len(visitor.disjunctions)] = [equality_disj, inequality_disj]
    return equality_disj.indicator_var.get_associated_binary()


def _dispatch_atleast(visitor, node, *args):
    """Encode ``atleast(n, ...)``: ``z`` is true iff ``sum(args[1:]) >= n``.

    This is implemented as the disjunction
    ``[sum(args[1:]) >= n] v [sum(args[1:]) <= n - 1]``; the binary associated
    with the first disjunct's indicator variable is returned.
    """
    threshold = _get_integer_value(args[0], node)
    total = sum(args[1:])
    disjuncts = visitor.disjuncts
    # Both disjuncts are fetched (in this order) before either is populated.
    satisfied = disjuncts[len(disjuncts)]
    violated = disjuncts[len(disjuncts)]
    satisfied.constraint = Constraint(expr=total >= threshold)
    violated.constraint = Constraint(expr=total <= threshold - 1)
    disjunctions = visitor.disjunctions
    disjunctions[len(disjunctions)] = [satisfied, violated]
    return satisfied.indicator_var.get_associated_binary()


def _dispatch_atmost(visitor, node, *args):
    """Encode ``atmost(n, ...)``: ``z`` is true iff ``sum(args[1:]) <= n``.

    This is implemented as the disjunction
    ``[sum(args[1:]) <= n] v [sum(args[1:]) >= n + 1]``; the binary associated
    with the first disjunct's indicator variable is returned.
    """
    threshold = _get_integer_value(args[0], node)
    total = sum(args[1:])
    disjuncts = visitor.disjuncts
    # Both disjuncts are fetched (in this order) before either is populated.
    satisfied = disjuncts[len(disjuncts)]
    violated = disjuncts[len(disjuncts)]
    satisfied.constraint = Constraint(expr=total <= threshold)
    violated.constraint = Constraint(expr=total >= threshold + 1)
    disjunctions = visitor.disjunctions
    disjunctions[len(disjunctions)] = [satisfied, violated]
    return satisfied.indicator_var.get_associated_binary()


# Maps each logical expression class to the handler that exitNode calls with
# the already-transformed arguments of that node.
_operator_dispatcher = {
    EXPR.ImplicationExpression: _dispatch_implication,
    EXPR.EquivalenceExpression: _dispatch_equivalence,
    EXPR.NotExpression: _dispatch_not,
    EXPR.AndExpression: _dispatch_and,
    EXPR.OrExpression: _dispatch_or,
    EXPR.XorExpression: _dispatch_xor,
    EXPR.ExactlyExpression: _dispatch_exactly,
    EXPR.AtLeastExpression: _dispatch_atleast,
    EXPR.AtMostExpression: _dispatch_atmost,
}

# Maps each leaf / named-expression class to the handler that beforeChild
# calls instead of descending into the child.
_before_child_dispatcher = {
    BV.ScalarBooleanVar: _dispatch_boolean_var,
    BV.BooleanVarData: _dispatch_boolean_var,
    AutoLinkedBooleanVar: _dispatch_boolean_var,
    ParamData: _dispatch_param,
    ScalarParam: _dispatch_param,
    # for the moment, these are all just so we can get good error messages
    # when we don't handle them:
    ScalarVar: _dispatch_var,
    VarData: _dispatch_var,
    ExpressionData: _dispatch_expression,
    ScalarExpression: _dispatch_expression,
}


class LogicalToDisjunctiveVisitor(StreamBasedExpressionVisitor):
    """Converts BooleanExpressions to Linear (MIP) representation

    This converter eschews conjunctive normal form, and instead follows
    the well-trodden MINLP path of factorable programming.
    """

    def __init__(self):
        """Create the containers that the dispatch handlers populate."""
        super().__init__()
        # Binary variables created by the handlers
        self.z_vars = VarList(domain=Binary)
        self.z_vars.construct()
        # Linear constraints tying those variables together
        self.constraints = ConstraintList()
        # Disjuncts / disjunctions created for the cardinality handlers
        self.disjuncts = Disjunct(NonNegativeIntegers, concrete=True)
        self.disjunctions = Disjunction(NonNegativeIntegers)
        self.disjunctions.construct()
        # Cache used by the negation handler: operand -> its negation variable
        self.expansions = ComponentMap()
        # Cache used by the BooleanVar handler: BooleanVar -> binary variable
        self.boolean_to_binary_map = ComponentMap()

    def initializeWalker(self, expr):
        """Handle a root that is itself a leaf; otherwise start the walk.

        Returns ``(False, finalized_result)`` when ``beforeChild`` already
        resolved ``expr``, and ``(True, expr)`` when the walk should proceed.
        """
        walk, result = self.beforeChild(None, expr, 0)
        if not walk:
            return False, self.finalizeResult(result)
        return True, expr

    def beforeChild(self, node, child, child_idx):
        """Resolve leaf children directly; request descent for operators.

        Returns ``(False, value)`` when ``child`` is replaced by ``value``
        without descending, and ``(True, None)`` when it must be walked.
        """
        if child.__class__ in EXPR.native_types:
            if child.__class__ is bool:
                # If we encounter a bool, we are going to need to treat it as
                # binary explicitly because we are finally pedantic enough in
                # the expression system to not allow some of the mixing we
                # will need (like summing a LinearExpression with a bool)
                return False, int(child)
            return False, child

        if child.is_numeric_type():
            # Just pass it through, we'll figure it out later
            return False, child

        if child.is_expression_type(ExpressionType.RELATIONAL):
            # Eventually we'll handle these. Right now we set a MouseTrap
            return _before_relational_expr(self, child)

        if not child.is_expression_type() or child.is_named_expression_type():
            return _before_child_dispatcher[child.__class__](self, child)

        return True, None

    def exitNode(self, node, data):
        """Apply the handler registered for ``node``'s class to its child results."""
        return _operator_dispatcher[node.__class__](self, node, *data)

    def finalizeResult(self, result):
        """Require the overall result to be true and return it."""
        # This LogicalExpression must evaluate to True (but note that we cannot
        # fix this variable to 1 since this logical expression could be living
        # on a Disjunct and later need to be relaxed.)
        self.constraints.add(result >= 1)
        return result