Implemented disjunctive normal form.

This commit is contained in:
Patrick Lühne 2016-09-08 03:42:32 +02:00
parent 34496a7158
commit 4fc9b35c13
12 changed files with 273 additions and 1 deletions

View File

@ -120,11 +120,14 @@ class Expression
virtual Type expressionType() const = 0; virtual Type expressionType() const = 0;
virtual ExpressionPointer copy();
ExpressionPointer normalized(); ExpressionPointer normalized();
virtual ExpressionPointer reduced(); virtual ExpressionPointer reduced();
virtual ExpressionPointer negationNormalized(); virtual ExpressionPointer negationNormalized();
virtual ExpressionPointer prenex(Expression::Type lastQuantifierType = Expression::Type::Exists); virtual ExpressionPointer prenex(Expression::Type lastQuantifierType = Expression::Type::Exists);
virtual ExpressionPointer simplified(); virtual ExpressionPointer simplified();
virtual ExpressionPointer disjunctionNormalized();
ExpressionPointer negated(); ExpressionPointer negated();
virtual void print(std::ostream &ostream) const = 0; virtual void print(std::ostream &ostream) const = 0;

View File

@ -22,6 +22,9 @@ class And: public NAry<And>
static const Expression::Type ExpressionType = Expression::Type::And; static const Expression::Type ExpressionType = Expression::Type::And;
static const std::string Identifier; static const std::string Identifier;
public:
ExpressionPointer disjunctionNormalized() override;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -32,6 +32,8 @@ class At: public ExpressionCRTP<At>
public: public:
At(); At();
ExpressionPointer copy() override;
size_t timePoint() const; size_t timePoint() const;
void setArgument(ExpressionPointer argument); void setArgument(ExpressionPointer argument);
@ -41,6 +43,7 @@ class At: public ExpressionCRTP<At>
ExpressionPointer negationNormalized() override; ExpressionPointer negationNormalized() override;
ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer prenex(Expression::Type lastExpressionType) override;
ExpressionPointer simplified() override; ExpressionPointer simplified() override;
ExpressionPointer disjunctionNormalized() override;
void print(std::ostream &ostream) const override; void print(std::ostream &ostream) const override;

View File

@ -29,12 +29,16 @@ class Binary: public ExpressionCRTP<Derived>
ExpressionContext &expressionContext, ExpressionParser parseExpression); ExpressionContext &expressionContext, ExpressionParser parseExpression);
public: public:
ExpressionPointer copy() override;
void setArgument(size_t i, ExpressionPointer argument); void setArgument(size_t i, ExpressionPointer argument);
const std::array<ExpressionPointer, 2> &arguments() const; const std::array<ExpressionPointer, 2> &arguments() const;
ExpressionPointer reduced() override; ExpressionPointer reduced() override;
ExpressionPointer negationNormalized() override; ExpressionPointer negationNormalized() override;
ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer prenex(Expression::Type lastExpressionType) override;
ExpressionPointer simplified() override;
ExpressionPointer disjunctionNormalized() override;
void print(std::ostream &ostream) const override; void print(std::ostream &ostream) const override;
@ -73,6 +77,19 @@ boost::intrusive_ptr<Derived> Binary<Derived>::parse(Context &context,
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
ExpressionPointer Binary<Derived>::copy()
{
auto result = new Derived;
for (size_t i = 0; i < m_arguments.size(); i++)
result->m_arguments[i] = m_arguments[i]->copy();
return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
void Binary<Derived>::setArgument(size_t i, ExpressionPointer expression) void Binary<Derived>::setArgument(size_t i, ExpressionPointer expression)
{ {
@ -130,6 +147,36 @@ inline ExpressionPointer Binary<Derived>::prenex(Expression::Type)
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
inline ExpressionPointer Binary<Derived>::simplified()
{
for (size_t i = 0; i < m_arguments.size(); i++)
{
BOOST_ASSERT(m_arguments[i]);
m_arguments[i] = m_arguments[i]->simplified();
}
return this;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
inline ExpressionPointer Binary<Derived>::disjunctionNormalized()
{
for (size_t i = 0; i < m_arguments.size(); i++)
{
BOOST_ASSERT(m_arguments[i]);
m_arguments[i] = m_arguments[i]->disjunctionNormalized();
}
return this;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
inline void Binary<Derived>::print(std::ostream &ostream) const inline void Binary<Derived>::print(std::ostream &ostream) const
{ {

View File

@ -29,6 +29,8 @@ class NAry: public ExpressionCRTP<Derived>
ExpressionContext &expressionContext, ExpressionParser parseExpression); ExpressionContext &expressionContext, ExpressionParser parseExpression);
public: public:
ExpressionPointer copy() override;
void setArgument(size_t i, ExpressionPointer argument); void setArgument(size_t i, ExpressionPointer argument);
void addArgument(ExpressionPointer argument); void addArgument(ExpressionPointer argument);
Expressions &arguments(); Expressions &arguments();
@ -38,6 +40,7 @@ class NAry: public ExpressionCRTP<Derived>
ExpressionPointer negationNormalized() override; ExpressionPointer negationNormalized() override;
ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer prenex(Expression::Type lastExpressionType) override;
ExpressionPointer simplified() override; ExpressionPointer simplified() override;
ExpressionPointer disjunctionNormalized() override;
void print(std::ostream &ostream) const override; void print(std::ostream &ostream) const override;
@ -85,6 +88,21 @@ boost::intrusive_ptr<Derived> NAry<Derived>::parse(Context &context,
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
ExpressionPointer NAry<Derived>::copy()
{
auto result = new Derived;
result->m_arguments.resize(m_arguments.size());
for (size_t i = 0; i < m_arguments.size(); i++)
result->m_arguments[i] = m_arguments[i]->copy();
return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
void NAry<Derived>::setArgument(size_t i, ExpressionPointer expression) void NAry<Derived>::setArgument(size_t i, ExpressionPointer expression)
{ {
@ -273,6 +291,21 @@ inline ExpressionPointer NAry<Derived>::simplified()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
inline ExpressionPointer NAry<Derived>::disjunctionNormalized()
{
for (size_t i = 0; i < m_arguments.size(); i++)
{
BOOST_ASSERT(m_arguments[i]);
m_arguments[i] = m_arguments[i]->disjunctionNormalized();
}
return this;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
inline void NAry<Derived>::print(std::ostream &ostream) const inline void NAry<Derived>::print(std::ostream &ostream) const
{ {

View File

@ -29,6 +29,8 @@ class Not: public ExpressionCRTP<Not>
public: public:
Not(); Not();
ExpressionPointer copy() override;
void setArgument(ExpressionPointer argument); void setArgument(ExpressionPointer argument);
ExpressionPointer argument() const; ExpressionPointer argument() const;
@ -36,6 +38,7 @@ class Not: public ExpressionCRTP<Not>
ExpressionPointer negationNormalized() override; ExpressionPointer negationNormalized() override;
ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer prenex(Expression::Type lastExpressionType) override;
ExpressionPointer simplified() override; ExpressionPointer simplified() override;
ExpressionPointer disjunctionNormalized() override;
void print(std::ostream &ostream) const override; void print(std::ostream &ostream) const override;

View File

@ -49,10 +49,13 @@ class QuantifiedCRTP: public Quantified
return Derived::ExpressionType; return Derived::ExpressionType;
} }
ExpressionPointer copy() override;
ExpressionPointer reduced() override; ExpressionPointer reduced() override;
ExpressionPointer negationNormalized() override; ExpressionPointer negationNormalized() override;
ExpressionPointer prenex(Expression::Type lastExpressionType) override; ExpressionPointer prenex(Expression::Type lastExpressionType) override;
ExpressionPointer simplified() override; ExpressionPointer simplified() override;
ExpressionPointer disjunctionNormalized() override;
void print(std::ostream &ostream) const override; void print(std::ostream &ostream) const override;
}; };
@ -98,6 +101,18 @@ boost::intrusive_ptr<Derived> QuantifiedCRTP<Derived>::parse(Context &context,
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
ExpressionPointer QuantifiedCRTP<Derived>::copy()
{
auto result = new Derived;
result->m_argument = m_argument->copy();
return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline void Quantified::setArgument(ExpressionPointer expression) inline void Quantified::setArgument(ExpressionPointer expression)
{ {
m_argument = expression; m_argument = expression;
@ -190,6 +205,18 @@ inline ExpressionPointer QuantifiedCRTP<Derived>::simplified()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived>
inline ExpressionPointer QuantifiedCRTP<Derived>::disjunctionNormalized()
{
BOOST_ASSERT(m_argument);
m_argument = m_argument->disjunctionNormalized();
return this;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Derived> template<class Derived>
inline void QuantifiedCRTP<Derived>::print(std::ostream &ostream) const inline void QuantifiedCRTP<Derived>::print(std::ostream &ostream) const
{ {

View File

@ -26,9 +26,16 @@ namespace pddl
// //
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer Expression::copy()
{
return this;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer Expression::normalized() ExpressionPointer Expression::normalized()
{ {
return reduced()->negationNormalized()->prenex()->simplified(); return reduced()->negationNormalized()->prenex()->simplified()->disjunctionNormalized()->simplified();
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -97,6 +104,13 @@ ExpressionPointer Expression::simplified()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer Expression::disjunctionNormalized()
{
return this;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer Expression::negated() ExpressionPointer Expression::negated()
{ {
if (expressionType() == Type::Not) if (expressionType() == Type::Not)

View File

@ -3,6 +3,8 @@
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <plasp/pddl/expressions/Or.h>
namespace plasp namespace plasp
{ {
namespace pddl namespace pddl
@ -20,6 +22,51 @@ const std::string And::Identifier = "and";
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer And::disjunctionNormalized()
{
for (size_t i = 0; i < m_arguments.size(); i++)
{
BOOST_ASSERT(m_arguments[i]);
m_arguments[i] = m_arguments[i]->disjunctionNormalized();
}
const auto match = std::find_if(m_arguments.begin(), m_arguments.end(),
[](const auto &argument)
{
return argument->expressionType() == Expression::Type::Or;
});
if (match == m_arguments.end())
return this;
auto orExpression = OrPointer(dynamic_cast<expressions::Or *>(match->get()));
const size_t orExpressionIndex = match - m_arguments.begin();
// Apply the distributive law
// Copy this and expression for each argument of the or expression
for (size_t i = 0; i < orExpression->arguments().size(); i++)
{
auto newAndExpression = new expressions::And;
newAndExpression->arguments().resize(m_arguments.size());
for (size_t j = 0; j < m_arguments.size(); j++)
{
if (j == orExpressionIndex)
newAndExpression->arguments()[j] = orExpression->arguments()[i]->copy();
else
newAndExpression->arguments()[j] = m_arguments[j]->copy();
}
// Replace the respective argument with the new, recursively normalized and expression
orExpression->arguments()[i] = newAndExpression->disjunctionNormalized();
}
return orExpression;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} }
} }
} }

View File

@ -22,6 +22,17 @@ At::At()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer At::copy()
{
auto result = new At;
result->m_argument = m_argument->copy();
return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void At::setArgument(ExpressionPointer argument) void At::setArgument(ExpressionPointer argument)
{ {
m_argument = argument; m_argument = argument;
@ -76,6 +87,17 @@ ExpressionPointer At::simplified()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer At::disjunctionNormalized()
{
BOOST_ASSERT(m_argument);
m_argument = m_argument->disjunctionNormalized();
return this;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void At::print(std::ostream &ostream) const void At::print(std::ostream &ostream) const
{ {
ostream << "(at " << m_timePoint << " "; ostream << "(at " << m_timePoint << " ";

View File

@ -25,6 +25,17 @@ Not::Not()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer Not::copy()
{
auto result = new Not;
result->m_argument = m_argument->copy();
return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void Not::setArgument(ExpressionPointer argument) void Not::setArgument(ExpressionPointer argument)
{ {
m_argument = argument; m_argument = argument;
@ -150,6 +161,17 @@ ExpressionPointer Not::simplified()
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
ExpressionPointer Not::disjunctionNormalized()
{
BOOST_ASSERT(m_argument);
m_argument = m_argument->disjunctionNormalized();
return this;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void Not::print(std::ostream &ostream) const void Not::print(std::ostream &ostream) const
{ {
ostream << "(not "; ostream << "(not ";

View File

@ -302,3 +302,51 @@ TEST(PDDLNormalizationTests, PrenexGroupSameType)
ASSERT_EQ(output.str(), "(forall (?v1 ?v2 ?v6 ?v7) (exists (?v3 ?v8) (forall (?v4 ?v9) (exists (?v5) (and (a) (b))))))"); ASSERT_EQ(output.str(), "(forall (?v1 ?v2 ?v6 ?v7) (exists (?v3 ?v8) (forall (?v4 ?v9) (exists (?v5) (and (a) (b))))))");
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
TEST(PDDLNormalizationTests, DisjunctiveNormalForm)
{
auto f = expressions::ForAllPointer(new expressions::ForAll);
auto e = expressions::ExistsPointer(new expressions::Exists);
auto a = expressions::AndPointer(new expressions::And);
auto o1 = expressions::OrPointer(new expressions::Or);
auto o2 = expressions::OrPointer(new expressions::Or);
auto o3 = expressions::OrPointer(new expressions::Or);
f->variables() = {new expressions::Variable("v1")};
f->setArgument(e);
e->variables() = {new expressions::Variable("v2")};
e->setArgument(o1);
o1->addArgument(a);
o1->addArgument(new expressions::Dummy("h"));
a->addArgument(new expressions::Dummy("a"));
a->addArgument(new expressions::Dummy("b"));
a->addArgument(o2);
a->addArgument(o3);
o2->addArgument(new expressions::Dummy("c"));
o2->addArgument(new expressions::Dummy("d"));
o2->addArgument(new expressions::Dummy("e"));
o3->addArgument(new expressions::Dummy("f"));
o3->addArgument(new expressions::Dummy("g"));
auto normalized = f->normalized();
std::stringstream output;
normalized->print(output);
ASSERT_EQ(output.str(), "(forall (?v1) (exists (?v2) (or "
"(and (a) (b) (c) (f)) "
"(h) "
"(and (a) (b) (d) (f)) "
"(and (a) (b) (e) (f)) "
"(and (a) (b) (c) (g)) "
"(and (a) (b) (d) (g)) "
"(and (a) (b) (e) (g))"
")))");
}