Passes that use Matcher

  • CPUFusion (GraphRewrite)

  • CoreFusion (GraphRewrite)

  • ReshapeElimination (GraphRewrite)

  • AlgebraicSimplification

  • CPUPostLayoutOptimizations (GraphRewrite)

  • CPURnnMatFusion

  • and many more…

Register simplify_neg handler

static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
         initialize_const_values_to_ops()
     {
         return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>({
             {TI(op::Add), simplify_add},
             {TI(op::Multiply), simplify_multiply},
             {TI(op::Sum), simplify_sum},
             {TI(op::Negative), simplify_neg}
         });
     }

Add a fusion

max(0, A) = Relu(A)

Pattern for capturing

image11

max(0, A) = Relu(A)

namespace ngraph
 {
     namespace pass
     {
         class CoreFusion;
     }
 }

 class ngraph::pass::CoreFusion : public ngraph::pass::GraphRewrite
 {
 public:
     CoreFusion()
         : GraphRewrite()
     {
         construct_relu_pattern();
     }

     //this should go in a cpp file.
     void construct_relu_pattern()
     {
         auto iconst0 = ngraph::make_zero(element::i32, Shape{});
         auto val = make_shared(iconst0);
         auto zero = make_shared(iconst0, nullptr, NodeVector{iconst0});

         auto broadcast_pred = [](std::shared_ptr n) {
             return static_cast(std::dynamic_pointer_cast(n));
         };
         auto skip_broadcast = std::make_shared(zero, broadcast_pred);
         auto max = make_shared(skip_broadcast, val);

     pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
             NGRAPH_DEBUG << "In a callback for construct_relu_pattern against "
                         << m.get_match_root()->get_name();

             auto pattern_map = m.get_pattern_map();
             auto mzero = m.get_pattern_map()[zero];
             if (!ngraph::is_zero(mzero))
             {
                 NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0n";
                 return false;
             }
             auto mpattern = m.get_match_root();

             auto cg = shared_ptr(new op::Relu(pattern_map[val]));
             ngraph::replace_node(m.get_match_root(), cg);
             return true;
         };

         auto m = make_shared(max, callback);
         this->add_matcher(m);
     }
 };

Recurrent patterns

Analogous to the "+" (one-or-more) quantifier in regexes, e.g. "A(BC)+A": a sub-pattern that may repeat one or more times

(((A + 0) + 0) + 0) = A

image12

image13

Shape shape{};
 auto a = make_shared<op::Parameter>(element::i32, shape);
 auto b = make_shared<op::Parameter>(element::i32, shape);
 auto rpattern = std::make_shared<pattern::op::Label>(b);
 auto iconst0 = ngraph::make_zero(element::i32, shape);
 auto abs = make_shared<op::Abs>(a);
 auto add1 = iconst0 + b;
 auto add2 = iconst0 + add1;
 auto add3 = iconst0 + add2;
 auto padd = iconst0 + rpattern;
 std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
 RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
 ASSERT_TRUE(rm.match(add3));