Implemented disjunctive normal form.
This commit is contained in:
parent
34496a7158
commit
4fc9b35c13
@ -120,11 +120,14 @@ class Expression
|
||||
|
||||
virtual Type expressionType() const = 0;
|
||||
|
||||
virtual ExpressionPointer copy();
|
||||
|
||||
ExpressionPointer normalized();
|
||||
virtual ExpressionPointer reduced();
|
||||
virtual ExpressionPointer negationNormalized();
|
||||
virtual ExpressionPointer prenex(Expression::Type lastQuantifierType = Expression::Type::Exists);
|
||||
virtual ExpressionPointer simplified();
|
||||
virtual ExpressionPointer disjunctionNormalized();
|
||||
ExpressionPointer negated();
|
||||
|
||||
virtual void print(std::ostream &ostream) const = 0;
|
||||
|
@ -22,6 +22,9 @@ class And: public NAry<And>
|
||||
static const Expression::Type ExpressionType = Expression::Type::And;
|
||||
|
||||
static const std::string Identifier;
|
||||
|
||||
public:
|
||||
ExpressionPointer disjunctionNormalized() override;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -32,6 +32,8 @@ class At: public ExpressionCRTP<At>
|
||||
public:
|
||||
At();
|
||||
|
||||
ExpressionPointer copy() override;
|
||||
|
||||
size_t timePoint() const;
|
||||
|
||||
void setArgument(ExpressionPointer argument);
|
||||
@ -41,6 +43,7 @@ class At: public ExpressionCRTP<At>
|
||||
ExpressionPointer negationNormalized() override;
|
||||
ExpressionPointer prenex(Expression::Type lastExpressionType) override;
|
||||
ExpressionPointer simplified() override;
|
||||
ExpressionPointer disjunctionNormalized() override;
|
||||
|
||||
void print(std::ostream &ostream) const override;
|
||||
|
||||
|
@ -29,12 +29,16 @@ class Binary: public ExpressionCRTP<Derived>
|
||||
ExpressionContext &expressionContext, ExpressionParser parseExpression);
|
||||
|
||||
public:
|
||||
ExpressionPointer copy() override;
|
||||
|
||||
void setArgument(size_t i, ExpressionPointer argument);
|
||||
const std::array<ExpressionPointer, 2> &arguments() const;
|
||||
|
||||
ExpressionPointer reduced() override;
|
||||
ExpressionPointer negationNormalized() override;
|
||||
ExpressionPointer prenex(Expression::Type lastExpressionType) override;
|
||||
ExpressionPointer simplified() override;
|
||||
ExpressionPointer disjunctionNormalized() override;
|
||||
|
||||
void print(std::ostream &ostream) const override;
|
||||
|
||||
@ -73,6 +77,19 @@ boost::intrusive_ptr<Derived> Binary<Derived>::parse(Context &context,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
ExpressionPointer Binary<Derived>::copy()
|
||||
{
|
||||
auto result = new Derived;
|
||||
|
||||
for (size_t i = 0; i < m_arguments.size(); i++)
|
||||
result->m_arguments[i] = m_arguments[i]->copy();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
void Binary<Derived>::setArgument(size_t i, ExpressionPointer expression)
|
||||
{
|
||||
@ -130,6 +147,36 @@ inline ExpressionPointer Binary<Derived>::prenex(Expression::Type)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
inline ExpressionPointer Binary<Derived>::simplified()
|
||||
{
|
||||
for (size_t i = 0; i < m_arguments.size(); i++)
|
||||
{
|
||||
BOOST_ASSERT(m_arguments[i]);
|
||||
|
||||
m_arguments[i] = m_arguments[i]->simplified();
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
inline ExpressionPointer Binary<Derived>::disjunctionNormalized()
|
||||
{
|
||||
for (size_t i = 0; i < m_arguments.size(); i++)
|
||||
{
|
||||
BOOST_ASSERT(m_arguments[i]);
|
||||
|
||||
m_arguments[i] = m_arguments[i]->disjunctionNormalized();
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
inline void Binary<Derived>::print(std::ostream &ostream) const
|
||||
{
|
||||
|
@ -29,6 +29,8 @@ class NAry: public ExpressionCRTP<Derived>
|
||||
ExpressionContext &expressionContext, ExpressionParser parseExpression);
|
||||
|
||||
public:
|
||||
ExpressionPointer copy() override;
|
||||
|
||||
void setArgument(size_t i, ExpressionPointer argument);
|
||||
void addArgument(ExpressionPointer argument);
|
||||
Expressions &arguments();
|
||||
@ -38,6 +40,7 @@ class NAry: public ExpressionCRTP<Derived>
|
||||
ExpressionPointer negationNormalized() override;
|
||||
ExpressionPointer prenex(Expression::Type lastExpressionType) override;
|
||||
ExpressionPointer simplified() override;
|
||||
ExpressionPointer disjunctionNormalized() override;
|
||||
|
||||
void print(std::ostream &ostream) const override;
|
||||
|
||||
@ -85,6 +88,21 @@ boost::intrusive_ptr<Derived> NAry<Derived>::parse(Context &context,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
ExpressionPointer NAry<Derived>::copy()
|
||||
{
|
||||
auto result = new Derived;
|
||||
|
||||
result->m_arguments.resize(m_arguments.size());
|
||||
|
||||
for (size_t i = 0; i < m_arguments.size(); i++)
|
||||
result->m_arguments[i] = m_arguments[i]->copy();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
void NAry<Derived>::setArgument(size_t i, ExpressionPointer expression)
|
||||
{
|
||||
@ -273,6 +291,21 @@ inline ExpressionPointer NAry<Derived>::simplified()
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
inline ExpressionPointer NAry<Derived>::disjunctionNormalized()
|
||||
{
|
||||
for (size_t i = 0; i < m_arguments.size(); i++)
|
||||
{
|
||||
BOOST_ASSERT(m_arguments[i]);
|
||||
|
||||
m_arguments[i] = m_arguments[i]->disjunctionNormalized();
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
inline void NAry<Derived>::print(std::ostream &ostream) const
|
||||
{
|
||||
|
@ -29,6 +29,8 @@ class Not: public ExpressionCRTP<Not>
|
||||
public:
|
||||
Not();
|
||||
|
||||
ExpressionPointer copy() override;
|
||||
|
||||
void setArgument(ExpressionPointer argument);
|
||||
ExpressionPointer argument() const;
|
||||
|
||||
@ -36,6 +38,7 @@ class Not: public ExpressionCRTP<Not>
|
||||
ExpressionPointer negationNormalized() override;
|
||||
ExpressionPointer prenex(Expression::Type lastExpressionType) override;
|
||||
ExpressionPointer simplified() override;
|
||||
ExpressionPointer disjunctionNormalized() override;
|
||||
|
||||
void print(std::ostream &ostream) const override;
|
||||
|
||||
|
@ -49,10 +49,13 @@ class QuantifiedCRTP: public Quantified
|
||||
return Derived::ExpressionType;
|
||||
}
|
||||
|
||||
ExpressionPointer copy() override;
|
||||
|
||||
ExpressionPointer reduced() override;
|
||||
ExpressionPointer negationNormalized() override;
|
||||
ExpressionPointer prenex(Expression::Type lastExpressionType) override;
|
||||
ExpressionPointer simplified() override;
|
||||
ExpressionPointer disjunctionNormalized() override;
|
||||
|
||||
void print(std::ostream &ostream) const override;
|
||||
};
|
||||
@ -98,6 +101,18 @@ boost::intrusive_ptr<Derived> QuantifiedCRTP<Derived>::parse(Context &context,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
ExpressionPointer QuantifiedCRTP<Derived>::copy()
|
||||
{
|
||||
auto result = new Derived;
|
||||
|
||||
result->m_argument = m_argument->copy();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline void Quantified::setArgument(ExpressionPointer expression)
|
||||
{
|
||||
m_argument = expression;
|
||||
@ -190,6 +205,18 @@ inline ExpressionPointer QuantifiedCRTP<Derived>::simplified()
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
inline ExpressionPointer QuantifiedCRTP<Derived>::disjunctionNormalized()
|
||||
{
|
||||
BOOST_ASSERT(m_argument);
|
||||
|
||||
m_argument = m_argument->disjunctionNormalized();
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<class Derived>
|
||||
inline void QuantifiedCRTP<Derived>::print(std::ostream &ostream) const
|
||||
{
|
||||
|
@ -26,9 +26,16 @@ namespace pddl
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer Expression::copy()
|
||||
{
|
||||
return this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer Expression::normalized()
|
||||
{
|
||||
return reduced()->negationNormalized()->prenex()->simplified();
|
||||
return reduced()->negationNormalized()->prenex()->simplified()->disjunctionNormalized()->simplified();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -97,6 +104,13 @@ ExpressionPointer Expression::simplified()
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer Expression::disjunctionNormalized()
|
||||
{
|
||||
return this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer Expression::negated()
|
||||
{
|
||||
if (expressionType() == Type::Not)
|
||||
|
@ -3,6 +3,8 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
#include <plasp/pddl/expressions/Or.h>
|
||||
|
||||
namespace plasp
|
||||
{
|
||||
namespace pddl
|
||||
@ -20,6 +22,51 @@ const std::string And::Identifier = "and";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer And::disjunctionNormalized()
|
||||
{
|
||||
for (size_t i = 0; i < m_arguments.size(); i++)
|
||||
{
|
||||
BOOST_ASSERT(m_arguments[i]);
|
||||
|
||||
m_arguments[i] = m_arguments[i]->disjunctionNormalized();
|
||||
}
|
||||
|
||||
const auto match = std::find_if(m_arguments.begin(), m_arguments.end(),
|
||||
[](const auto &argument)
|
||||
{
|
||||
return argument->expressionType() == Expression::Type::Or;
|
||||
});
|
||||
|
||||
if (match == m_arguments.end())
|
||||
return this;
|
||||
|
||||
auto orExpression = OrPointer(dynamic_cast<expressions::Or *>(match->get()));
|
||||
const size_t orExpressionIndex = match - m_arguments.begin();
|
||||
|
||||
// Apply the distributive law
|
||||
// Copy this and expression for each argument of the or expression
|
||||
for (size_t i = 0; i < orExpression->arguments().size(); i++)
|
||||
{
|
||||
auto newAndExpression = new expressions::And;
|
||||
newAndExpression->arguments().resize(m_arguments.size());
|
||||
|
||||
for (size_t j = 0; j < m_arguments.size(); j++)
|
||||
{
|
||||
if (j == orExpressionIndex)
|
||||
newAndExpression->arguments()[j] = orExpression->arguments()[i]->copy();
|
||||
else
|
||||
newAndExpression->arguments()[j] = m_arguments[j]->copy();
|
||||
}
|
||||
|
||||
// Replace the respective argument with the new, recursively normalized and expression
|
||||
orExpression->arguments()[i] = newAndExpression->disjunctionNormalized();
|
||||
}
|
||||
|
||||
return orExpression;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,6 +22,17 @@ At::At()
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer At::copy()
|
||||
{
|
||||
auto result = new At;
|
||||
|
||||
result->m_argument = m_argument->copy();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void At::setArgument(ExpressionPointer argument)
|
||||
{
|
||||
m_argument = argument;
|
||||
@ -76,6 +87,17 @@ ExpressionPointer At::simplified()
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer At::disjunctionNormalized()
|
||||
{
|
||||
BOOST_ASSERT(m_argument);
|
||||
|
||||
m_argument = m_argument->disjunctionNormalized();
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void At::print(std::ostream &ostream) const
|
||||
{
|
||||
ostream << "(at " << m_timePoint << " ";
|
||||
|
@ -25,6 +25,17 @@ Not::Not()
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer Not::copy()
|
||||
{
|
||||
auto result = new Not;
|
||||
|
||||
result->m_argument = m_argument->copy();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void Not::setArgument(ExpressionPointer argument)
|
||||
{
|
||||
m_argument = argument;
|
||||
@ -150,6 +161,17 @@ ExpressionPointer Not::simplified()
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
ExpressionPointer Not::disjunctionNormalized()
|
||||
{
|
||||
BOOST_ASSERT(m_argument);
|
||||
|
||||
m_argument = m_argument->disjunctionNormalized();
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void Not::print(std::ostream &ostream) const
|
||||
{
|
||||
ostream << "(not ";
|
||||
|
@ -302,3 +302,51 @@ TEST(PDDLNormalizationTests, PrenexGroupSameType)
|
||||
|
||||
ASSERT_EQ(output.str(), "(forall (?v1 ?v2 ?v6 ?v7) (exists (?v3 ?v8) (forall (?v4 ?v9) (exists (?v5) (and (a) (b))))))");
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(PDDLNormalizationTests, DisjunctiveNormalForm)
|
||||
{
|
||||
auto f = expressions::ForAllPointer(new expressions::ForAll);
|
||||
auto e = expressions::ExistsPointer(new expressions::Exists);
|
||||
auto a = expressions::AndPointer(new expressions::And);
|
||||
auto o1 = expressions::OrPointer(new expressions::Or);
|
||||
auto o2 = expressions::OrPointer(new expressions::Or);
|
||||
auto o3 = expressions::OrPointer(new expressions::Or);
|
||||
|
||||
f->variables() = {new expressions::Variable("v1")};
|
||||
f->setArgument(e);
|
||||
|
||||
e->variables() = {new expressions::Variable("v2")};
|
||||
e->setArgument(o1);
|
||||
|
||||
o1->addArgument(a);
|
||||
o1->addArgument(new expressions::Dummy("h"));
|
||||
|
||||
a->addArgument(new expressions::Dummy("a"));
|
||||
a->addArgument(new expressions::Dummy("b"));
|
||||
a->addArgument(o2);
|
||||
a->addArgument(o3);
|
||||
|
||||
o2->addArgument(new expressions::Dummy("c"));
|
||||
o2->addArgument(new expressions::Dummy("d"));
|
||||
o2->addArgument(new expressions::Dummy("e"));
|
||||
|
||||
o3->addArgument(new expressions::Dummy("f"));
|
||||
o3->addArgument(new expressions::Dummy("g"));
|
||||
|
||||
auto normalized = f->normalized();
|
||||
|
||||
std::stringstream output;
|
||||
normalized->print(output);
|
||||
|
||||
ASSERT_EQ(output.str(), "(forall (?v1) (exists (?v2) (or "
|
||||
"(and (a) (b) (c) (f)) "
|
||||
"(h) "
|
||||
"(and (a) (b) (d) (f)) "
|
||||
"(and (a) (b) (e) (f)) "
|
||||
"(and (a) (b) (c) (g)) "
|
||||
"(and (a) (b) (d) (g)) "
|
||||
"(and (a) (b) (e) (g))"
|
||||
")))");
|
||||
}
|
||||
|
Reference in New Issue
Block a user