diff --git a/analyzer/server.go b/analyzer/server.go index afa21b5e14..9fac92a45c 100644 --- a/analyzer/server.go +++ b/analyzer/server.go @@ -333,8 +333,10 @@ func NewServerFromConfig() (*Server, error) { tr.AddTraversalExtension(ge.NewSocketsTraversalExtension()) tr.AddTraversalExtension(ge.NewDescendantsTraversalExtension()) tr.AddTraversalExtension(ge.NewAscendantsTraversalExtension()) + tr.AddTraversalExtension(ge.NewNeighborsTraversalExtension()) tr.AddTraversalExtension(ge.NewNextHopTraversalExtension()) tr.AddTraversalExtension(ge.NewGroupTraversalExtension()) + tr.AddTraversalExtension(ge.NewMergeTraversalExtension()) // new flow subscriber endpoints flowSubscriberWSServer := ws.NewStructServer(config.NewWSServer(hub.HTTPServer(), "/ws/subscriber/flow", apiAuthBackend)) diff --git a/go.mod b/go.mod index 9b18b8f3f1..773d68e8d0 100644 --- a/go.mod +++ b/go.mod @@ -77,6 +77,7 @@ require ( github.com/spf13/cobra v1.1.1 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.7.0 + github.com/stretchr/testify v1.7.0 github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c github.com/tebeka/go2xunit v1.4.10 github.com/tebeka/selenium v0.0.0-20170314201507-657e45ec600f diff --git a/go.sum b/go.sum index 160408838c..80faa1e473 100644 --- a/go.sum +++ b/go.sum @@ -1001,8 +1001,9 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/syndtr/gocapability v0.0.0-20160928074757-e7cb7fa329f4/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= diff --git a/graffiti/go.mod b/graffiti/go.mod index 647ff380f2..e24a38c5ab 100644 --- a/graffiti/go.mod +++ b/graffiti/go.mod @@ -40,6 +40,7 @@ require ( github.com/skydive-project/go-debouncer v1.0.0 github.com/spf13/cast v1.3.1 github.com/spf13/cobra v1.1.1 + github.com/stretchr/testify v1.7.0 github.com/tchap/zapext v1.0.0 github.com/xeipuuv/gojsonschema v1.2.0 go.uber.org/zap v1.16.0 diff --git a/graffiti/graph/cachedbackend.go b/graffiti/graph/cachedbackend.go index 72d109b833..44d429064b 100644 --- a/graffiti/graph/cachedbackend.go +++ b/graffiti/graph/cachedbackend.go @@ -99,6 +99,17 @@ func (c *CachedBackend) GetNode(i Identifier, t Context) []*Node { return c.persistent.GetNode(i, t) } +// GetNodesFromIDs retrieve the list of nodes for the list of identifiers from the cache within a time slice +func (c *CachedBackend) GetNodesFromIDs(i []Identifier, t Context) []*Node { + mode := c.cacheMode.Load() + + if t.TimeSlice == nil || mode == CacheOnlyMode || c.persistent == nil { + return 
c.memory.GetNodesFromIDs(i, t) + } + + return c.persistent.GetNodesFromIDs(i, t) +} + // GetNodeEdges retrieve a list of edges from a node within a time slice, matching metadata func (c *CachedBackend) GetNodeEdges(n *Node, t Context, m ElementMatcher) (edges []*Edge) { mode := c.cacheMode.Load() @@ -110,6 +121,17 @@ func (c *CachedBackend) GetNodeEdges(n *Node, t Context, m ElementMatcher) (edge return c.persistent.GetNodeEdges(n, t, m) } +// GetNodesEdges return the list of all edges for a list of nodes within time slice, matching metadata +func (c *CachedBackend) GetNodesEdges(n []*Node, t Context, m ElementMatcher) (edges []*Edge) { + mode := c.cacheMode.Load() + + if t.TimeSlice == nil || mode == CacheOnlyMode || c.persistent == nil { + return c.memory.GetNodesEdges(n, t, m) + } + + return c.persistent.GetNodesEdges(n, t, m) +} + // EdgeAdded add an edge in the cache func (c *CachedBackend) EdgeAdded(e *Edge) error { mode := c.cacheMode.Load() diff --git a/graffiti/graph/elasticsearch.go b/graffiti/graph/elasticsearch.go index 3fe281b428..f0aee4f711 100644 --- a/graffiti/graph/elasticsearch.go +++ b/graffiti/graph/elasticsearch.go @@ -87,6 +87,8 @@ const graphElementMapping = ` const ( nodeType = "node" edgeType = "edge" + // maxClauseCount limit the number of clauses in one query to ES + maxClauseCount = 512 ) // ElasticSearchBackend describes a persistent backend based on ElasticSearch @@ -242,6 +244,44 @@ func (b *ElasticSearchBackend) GetNode(i Identifier, t Context) []*Node { return nodes } +// GetNodesFromIDs get the list of nodes for the list of identifiers within a time slice +func (b *ElasticSearchBackend) GetNodesFromIDs(identifiersList []Identifier, t Context) []*Node { + if len(identifiersList) == 0 { + return []*Node{} + } + + // ES default max number of clauses is set by default to 1024 + // https://www.elastic.co/guide/en/elasticsearch/reference/current/search-settings.html + // Group queries in a maximum of half of the max. + // Other filters (time), will be also in the query. + identifiersBatch := batchIdentifiers(identifiersList, maxClauseCount) + + nodes := []*Node{} + + for _, idList := range identifiersBatch { + identifiersFilter := []*filters.Filter{} + for _, i := range idList { + identifiersFilter = append(identifiersFilter, filters.NewTermStringFilter("ID", string(i))) + } + identifiersORFilter := filters.NewOrFilter(identifiersFilter...) + + nodes = append(nodes, b.searchNodes(&TimedSearchQuery{ + SearchQuery: filters.SearchQuery{ + Filter: identifiersORFilter, + Sort: true, + SortBy: "Revision", + }, + TimeFilter: getTimeFilter(t.TimeSlice), + })...) 
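 + // Each batch becomes a single OR query of at most maxClauseCount ID term clauses, + // e.g. 1200 identifiers are split by batchIdentifiers into batches of 512, 512 and 176, + // keeping every request below the ES default limit of 1024 boolean clauses.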
+ } + + if len(nodes) > 1 && t.TimePoint { + return []*Node{nodes[len(nodes)-1]} + } + + return nodes +} + func (b *ElasticSearchBackend) indexEdge(e *Edge) error { raw, err := edgeToRaw(e) if err != nil { @@ -506,6 +546,53 @@ func (b *ElasticSearchBackend) GetNodeEdges(n *Node, t Context, m ElementMatcher return } +// GetNodesEdges return the list of all edges for a list of nodes within time slice +func (b *ElasticSearchBackend) GetNodesEdges(nodeList []*Node, t Context, m ElementMatcher) (edges []*Edge) { + if len(nodeList) == 0 { + return []*Edge{} + } + + // See comment at GetNodesFromIDs + // As we are adding two operations per item, make small batches + nodesBatch := batchNodes(nodeList, maxClauseCount/2) + + for _, nList := range nodesBatch { + var filter *filters.Filter + if m != nil { + f, err := m.Filter() + if err != nil { + return []*Edge{} + } + filter = f + } + + var searchQuery filters.SearchQuery + if !t.TimePoint { + searchQuery = filters.SearchQuery{Sort: true, SortBy: "UpdatedAt"} + } + + nodesFilter := []*filters.Filter{} + for _, n := range nList { + nodesFilter = append(nodesFilter, filters.NewTermStringFilter("Parent", string(n.ID))) + nodesFilter = append(nodesFilter, filters.NewTermStringFilter("Child", string(n.ID))) + } + searchQuery.Filter = filters.NewOrFilter(nodesFilter...) + + edges = append(edges, b.searchEdges(&TimedSearchQuery{ + SearchQuery: searchQuery, + TimeFilter: getTimeFilter(t.TimeSlice), + ElementFilter: filter, + })...) + + } + + if len(edges) > 1 && t.TimePoint { + edges = dedupEdges(edges) + } + + return +} + // IsHistorySupported returns that this backend does support history func (b *ElasticSearchBackend) IsHistorySupported() bool { return true @@ -647,3 +734,26 @@ func NewElasticSearchBackendFromConfig(cfg es.Config, extraDynamicTemplates map[ return newElasticSearchBackendFromClient(client, cfg.IndexPrefix, liveIndex, archiveIndex, logger), nil } + +func batchNodes(items []*Node, batchSize int) [][]*Node { + batches := make([][]*Node, 0, (len(items)+batchSize-1)/batchSize) + + for batchSize < len(items) { + items, batches = items[batchSize:], append(batches, items[0:batchSize:batchSize]) + } + batches = append(batches, items) + + return batches + +} + +func batchIdentifiers(items []Identifier, batchSize int) [][]Identifier { + batches := make([][]Identifier, 0, (len(items)+batchSize-1)/batchSize) + + for batchSize < len(items) { + items, batches = items[batchSize:], append(batches, items[0:batchSize:batchSize]) + } + batches = append(batches, items) + + return batches +} diff --git a/graffiti/graph/graph.go b/graffiti/graph/graph.go index c263fc9b95..291f59a30a 100644 --- a/graffiti/graph/graph.go +++ b/graffiti/graph/graph.go @@ -111,7 +111,9 @@ type Backend interface { NodeAdded(n *Node) error NodeDeleted(n *Node) error GetNode(i Identifier, at Context) []*Node + GetNodesFromIDs(i []Identifier, at Context) []*Node GetNodeEdges(n *Node, at Context, m ElementMatcher) []*Edge + GetNodesEdges(n []*Node, at Context, m ElementMatcher) []*Edge EdgeAdded(e *Edge) error EdgeDeleted(e *Edge) error @@ -573,6 +575,21 @@ func (n *Node) String() string { return string(b) } +func (n *Node) Copy() *Node { + return &Node{ + graphElement: graphElement{ + ID: n.ID, + Host: n.Host, + Origin: n.Origin, + CreatedAt: n.CreatedAt, + UpdatedAt: n.UpdatedAt, + DeletedAt: n.DeletedAt, + Revision: n.Revision, + Metadata: n.Metadata.Copy(), + }, + } +} + // UnmarshalJSON custom unmarshal function func (n *Node) UnmarshalJSON(b []byte) error { // wrapper to avoid 
unmarshal infinite loop @@ -1185,6 +1202,14 @@ func (g *Graph) GetNode(i Identifier) *Node { return nil } +// GetNodesFromIDs returns a list of nodes from a list of identifiers +func (g *Graph) GetNodesFromIDs(i []Identifier) []*Node { + if len(i) == 0 { + return []*Node{} + } + return g.backend.GetNodesFromIDs(i, g.context) +} + // CreateNode returns a new node not bound to a graph func CreateNode(i Identifier, m Metadata, t Time, h string, o string) *Node { n := &Node{ @@ -1365,6 +1390,14 @@ func (g *Graph) GetNodeEdges(n *Node, m ElementMatcher) []*Edge { return g.backend.GetNodeEdges(n, g.context, m) } +// GetNodesEdges returns the list with all edges for a list of nodes +func (g *Graph) GetNodesEdges(n []*Node, m ElementMatcher) []*Edge { + if len(n) == 0 { + return []*Edge{} + } + return g.backend.GetNodesEdges(n, g.context, m) +} + func (g *Graph) String() string { j, _ := json.Marshal(g) return string(j) diff --git a/graffiti/graph/graph_test.go b/graffiti/graph/graph_test.go index 5aa113f3e6..326f0d0211 100644 --- a/graffiti/graph/graph_test.go +++ b/graffiti/graph/graph_test.go @@ -22,6 +22,9 @@ import ( "strconv" "strings" "testing" + "time" + + "github.com/stretchr/testify/assert" ) func newGraph(t *testing.T) *Graph { @@ -509,6 +512,47 @@ func TestEvents(t *testing.T) { } } +func TestNodeCopy(t *testing.T) { + n := &Node{ + graphElement: graphElement{ + ID: Identifier("id"), + Host: "Host", + Origin: "Origin", + CreatedAt: Time(time.Unix(100, 0)), + UpdatedAt: Time(time.Unix(200, 0)), + DeletedAt: Time(time.Unix(300, 0)), + Revision: 1, + Metadata: Metadata{"foo": "bar"}, + }, + } + + cn := n.Copy() + assert.Equal(t, n, cn) + + // Check that modifications in the copied node do not affect the origin node + ok := cn.Metadata.SetField("new", "value") + assert.Truef(t, ok, "copied node set metadata") + assert.NotEqualf(t, n, cn, "Metadata") + + cn.Host = "modified" + assert.NotEqualf(t, n, cn, "Host") + + cn.Origin = "modified" + assert.NotEqualf(t, n, cn, "Origin") + + cn.Revision = 99 + assert.NotEqualf(t, n, cn, "Revision") + + cn.CreatedAt = Time(time.Unix(100, 99)) + assert.NotEqualf(t, n, cn, "CreatedAt") + + cn.UpdatedAt = Time(time.Unix(200, 99)) + assert.NotEqualf(t, n, cn, "UpdatedAt") + + cn.DeletedAt = Time(time.Unix(300, 99)) + assert.NotEqualf(t, n, cn, "DeletedAt") +} + type FakeRecursiveListener1 struct { DefaultGraphListener graph *Graph diff --git a/graffiti/graph/memory.go b/graffiti/graph/memory.go index 66e8f16204..a626cc55da 100644 --- a/graffiti/graph/memory.go +++ b/graffiti/graph/memory.go @@ -138,6 +138,17 @@ func (m *MemoryBackend) GetNode(i Identifier, t Context) []*Node { return nil } +// GetNodesFromIDs from the graph backend +func (m *MemoryBackend) GetNodesFromIDs(identifiersList []Identifier, t Context) []*Node { + nodes := []*Node{} + for _, i := range identifiersList { + if n, ok := m.nodes[i]; ok { + nodes = append(nodes, n.Node) + } + } + return nodes +} + // GetNodeEdges returns a list of edges of a node func (m *MemoryBackend) GetNodeEdges(n *Node, t Context, meta ElementMatcher) []*Edge { edges := []*Edge{} @@ -153,6 +164,22 @@ func (m *MemoryBackend) GetNodeEdges(n *Node, t Context, meta ElementMatcher) [] return edges } +// GetNodesEdges returns the list of edges for a list of nodes +func (m *MemoryBackend) GetNodesEdges(nodeList []*Node, t Context, meta ElementMatcher) []*Edge { + edges := []*Edge{} + for _, n := range nodeList { + if n, ok := m.nodes[n.ID]; ok { + for _, e := range n.edges { + if e.MatchMetadata(meta) { + edges 
= append(edges, e.Edge) + } + } + } + } + + return edges +} + // EdgeDeleted in the graph backend func (m *MemoryBackend) EdgeDeleted(e *Edge) error { if _, ok := m.edges[e.ID]; !ok { diff --git a/graffiti/graph/orientdb.go b/graffiti/graph/orientdb.go index 13bf00ea41..458a74d648 100644 --- a/graffiti/graph/orientdb.go +++ b/graffiti/graph/orientdb.go @@ -222,6 +222,23 @@ func (o *OrientDBBackend) GetNode(i Identifier, t Context) (nodes []*Node) { return o.searchNodes(t, query) } +func (o *OrientDBBackend) GetNodesFromIDs(identifiersList []Identifier, t Context) (nodes []*Node) { + query := orientdb.FilterToExpression(getTimeFilter(t.TimeSlice), nil) + query += fmt.Sprintf(" AND (") + for i, id := range identifiersList { + if i == len(identifiersList)-1 { + query += fmt.Sprintf(" ID = '%s') ORDER BY Revision", id) + } else { + query += fmt.Sprintf(" ID = '%s' OR", id) + } + } + + if t.TimePoint { + query += " DESC LIMIT 1" + } + return o.searchNodes(t, query) +} + // GetNodeEdges returns a list of a node edges within time slice func (o *OrientDBBackend) GetNodeEdges(n *Node, t Context, m ElementMatcher) (edges []*Edge) { query := orientdb.FilterToExpression(getTimeFilter(t.TimeSlice), nil) @@ -232,6 +249,23 @@ func (o *OrientDBBackend) GetNodeEdges(n *Node, t Context, m ElementMatcher) (ed return o.searchEdges(t, query) } +// GetNodesEdges returns a list of a node edges within time slice +func (o *OrientDBBackend) GetNodesEdges(nodeList []*Node, t Context, m ElementMatcher) (edges []*Edge) { + query := orientdb.FilterToExpression(getTimeFilter(t.TimeSlice), nil) + query += fmt.Sprintf(" AND (") + for i, n := range nodeList { + if i == len(nodeList)-1 { + query += fmt.Sprintf(" Parent = '%s' OR Child = '%s')", n.ID, n.ID) + } else { + query += fmt.Sprintf(" Parent = '%s' OR Child = '%s' OR", n.ID, n.ID) + } + } + if matcherQuery := matcherToOrientDBSelectString(m); matcherQuery != "" { + query += " AND " + matcherQuery + } + return o.searchEdges(t, query) +} + func (o *OrientDBBackend) createEdge(e *Edge) error { fromQuery := fmt.Sprintf("SELECT FROM Node WHERE DeletedAt IS NULL AND ArchivedAt IS NULL AND ID = '%s'", e.Parent) toQuery := fmt.Sprintf("SELECT FROM Node WHERE DeletedAt IS NULL AND ArchivedAt IS NULL AND ID = '%s'", e.Child) diff --git a/graffiti/graph/traversal/traversal.go b/graffiti/graph/traversal/traversal.go index ae35b015ae..4869079532 100644 --- a/graffiti/graph/traversal/traversal.go +++ b/graffiti/graph/traversal/traversal.go @@ -1415,14 +1415,12 @@ func (tv *GraphTraversalV) SubGraph(ctx StepContext, s ...interface{}) *GraphTra // then insert edges, ignore edge insert error since one of the linked node couldn't be part // of the SubGraph - for _, n := range tv.nodes { - edges := tv.GraphTraversal.Graph.GetNodeEdges(n, nil) - for _, e := range edges { - switch err := memory.EdgeAdded(e); err { - case nil, graph.ErrParentNotFound, graph.ErrChildNotFound, graph.ErrEdgeConflict: - default: - return &GraphTraversal{error: fmt.Errorf("Error while adding edge to SubGraph: %s", err)} - } + edges := tv.GraphTraversal.Graph.GetNodesEdges(tv.nodes, nil) + for _, e := range edges { + switch err := memory.EdgeAdded(e); err { + case nil, graph.ErrParentNotFound, graph.ErrChildNotFound, graph.ErrEdgeConflict: + default: + return &GraphTraversal{error: fmt.Errorf("Error while adding edge to SubGraph: %s", err)} } } diff --git a/gremlin/traversal/merge.go b/gremlin/traversal/merge.go new file mode 100644 index 0000000000..9b21864d3c --- /dev/null +++ b/gremlin/traversal/merge.go @@ 
-0,0 +1,258 @@ +package traversal + +import ( + "fmt" + "reflect" + "time" + + "github.com/pkg/errors" + "github.com/skydive-project/skydive/graffiti/graph" + "github.com/skydive-project/skydive/graffiti/graph/traversal" + "github.com/skydive-project/skydive/graffiti/logging" +) + +// MergeTraversalExtension describes a new extension to enhance the topology +type MergeTraversalExtension struct { + MergeToken traversal.Token +} + +// MergeGremlinTraversalStep aggregates elements from different revisions of the nodes into a new metadata key. +// Nodes returned by this step are copies of the nodes in the graph, not the actual nodes. +// The reason is that this step is not meant to modify nodes in the graph, only to produce output. +// This step should be used with a persistent backend, so it can access previous revisions of the nodes. +// To use this step we should select a metadata key (first parameter) from which the elements will be read. +// Inside this Metadata.Key elements should have the format map[interface{}]interface{} (could be a type based on that). +// The second parameter is the metadata key where all the elements will be aggregated. +// The aggregation will have the format: map[string][]interface{}. +// All elements with the same key in the map will be joined into a slice. +// To use this step we can use a graph with a time period context, eg: G.At(1479899809,3600).V().Merge('A','B'). +// Or we can define the time period in the step: G.V().Merge('A','B',1500000000,1500099999). +// Note that in this case we define the start and end time, while "At" takes a start time and a duration. +// In both cases, the Merge step will use the nodes given by the previous step. +type MergeGremlinTraversalStep struct { + traversal.GremlinTraversalContext + MergeKey string + MergeAggKey string + StartTime time.Time + EndTime time.Time +} + +// NewMergeTraversalExtension returns a new graph traversal extension +func NewMergeTraversalExtension() *MergeTraversalExtension { + return &MergeTraversalExtension{ + MergeToken: traversalMergeToken, + } +} + +// ScanIdent recognises the word associated with this step (in uppercase) and returns a token +// which represents it. It returns true if it has found a match +func (e *MergeTraversalExtension) ScanIdent(s string) (traversal.Token, bool) { + switch s { + case "MERGE": + return e.MergeToken, true + } + return traversal.IDENT, false +} + +// ParseStep generates a step for a given token, taking the context and params from 'p' +func (e *MergeTraversalExtension) ParseStep(t traversal.Token, p traversal.GremlinTraversalContext) (traversal.GremlinTraversalStep, error) { + switch t { + case e.MergeToken: + default: + return nil, nil + } + + var mergeKey, mergeAggKey string + var startTime, endTime time.Time + var ok bool + + switch len(p.Params) { + case 2: + mergeKey, ok = p.Params[0].(string) + if !ok { + return nil, errors.New("Merge first parameter have to be a string") + } + mergeAggKey, ok = p.Params[1].(string) + if !ok { + return nil, errors.New("Merge second parameter have to be a string") + } + case 4: + mergeKey, ok = p.Params[0].(string) + if !ok { + return nil, errors.New("Merge first parameter have to be a string") + } + mergeAggKey, ok = p.Params[1].(string) + if !ok { + return nil, errors.New("Merge second parameter have to be a string") + } + startTimeUnixEpoch, ok := p.Params[2].(int64) + if !ok { + return nil, errors.New("Merge third parameter have to be a int (unix epoch time)") + } + startTime = time.Unix(startTimeUnixEpoch, 0) + endTimeUnixEpoch, ok := p.Params[3].(int64) + if !ok { + return nil, errors.New("Merge fourth parameter have to be a int (unix epoch time)") + } + endTime = time.Unix(endTimeUnixEpoch, 0) + default: + return nil, errors.New("Merge parameter must have two or four parameters (OriginKey, DestinationKey, StartTime, EndTime)") + } + + return &MergeGremlinTraversalStep{ + GremlinTraversalContext: p, + MergeKey: mergeKey, + MergeAggKey: mergeAggKey, + StartTime: startTime, + EndTime: endTime, + }, nil +} + +// Exec executes the merge step +func (s *MergeGremlinTraversalStep) Exec(last traversal.GraphTraversalStep) (traversal.GraphTraversalStep, error) { + switch tv := last.(type) { + case *traversal.GraphTraversalV: + return s.InterfaceMerge(tv) + + } + return nil, traversal.ErrExecutionError +} + +// Reduce merge step +func (s *MergeGremlinTraversalStep) Reduce(next traversal.GremlinTraversalStep) (traversal.GremlinTraversalStep, error) { + return next, nil +} + +// Context merge step +func (s *MergeGremlinTraversalStep) Context() *traversal.GremlinTraversalContext { + return &s.GremlinTraversalContext +} + +// InterfaceMerge groups, for each node id, all the elements stored in Metadata.key across the +// input nodes and puts them into Metadata.aggKey of the newest node for that id. +// Elements are grouped by their key. See mergeMetadata for an example. +// All output nodes will have Metadata.aggKey defined (empty or not). +func (s *MergeGremlinTraversalStep) InterfaceMerge(tv *traversal.GraphTraversalV) (traversal.GraphTraversalStep, error) { + // If the user has defined start/end time in the parameters, use those values instead of the ones coming with the graph + if !s.StartTime.IsZero() && !s.EndTime.IsZero() { + timeSlice := graph.NewTimeSlice( + graph.Time(s.StartTime).UnixMilli(), + graph.Time(s.EndTime).UnixMilli(), + ) + userTimeSliceCtx := graph.Context{ + TimeSlice: timeSlice, + TimePoint: true, + } + + newGraph, err := tv.GraphTraversal.Graph.CloneWithContext(userTimeSliceCtx) + if err != nil { + return nil, err + } + tv.GraphTraversal.Graph = newGraph + } + + tv.GraphTraversal.RLock() + defer tv.GraphTraversal.RUnlock() + + // uniqNodes stores the latest node for each node identifier + uniqNodes := map[graph.Identifier]*graph.Node{} + + // elements accumulates the elements for each node id + elements := map[graph.Identifier]map[string][]interface{}{} + + // Get the list of node ids + nodesIDs := make([]graph.Identifier, 0, len(tv.GetNodes())) + for _, node := range tv.GetNodes() { + nodesIDs = append(nodesIDs, node.ID) + } + + // Get all revisions for the list of node ids + revisionNodes := tv.GraphTraversal.Graph.GetNodesFromIDs(nodesIDs) + + // Store only the most recent nodes + for _, rNode := range revisionNodes { + storedNode, ok := uniqNodes[rNode.ID] + if !ok { + uniqNodes[rNode.ID] = rNode + } else { + if storedNode.Revision < rNode.Revision { + uniqNodes[rNode.ID] = rNode + } + } + + // Store elements from all revisions into the "elements" variable + elements[rNode.ID] = mergeMetadata(rNode, s.MergeKey, elements[rNode.ID]) + } + + // Move the nodes from the uniqNodes map to a slice as required by TraversalV. + // Return a copy of the nodes, not the actual graph nodes, because we don't want + // to modify nodes with this step, just append some extra info + nodes := []*graph.Node{} + for id, n := range uniqNodes { + nCopy := n.Copy() + + e, ok := elements[id] + if ok { + // Set the stored node with the merge of previous and current node + metadataSet := nCopy.Metadata.SetField(s.MergeAggKey, e) + if !metadataSet { + return nil, fmt.Errorf("unable to set elements metadata for copied node %v", id) + } + } + + nodes = append(nodes, nCopy) + } + + return traversal.NewGraphTraversalV(tv.GraphTraversal, nodes), nil +} + +// mergeMetadata returns the merge of node.Key elements with the ones already stored in nodeMerge +// Eg.: +// node: Metadata.key: {"a":{x}, "b":{y}} +// nodeMerge: {"a":[{z}]} +// return: Metadata.key: {"a":[{x},{z}], "b":[{y}]} +// +// Ignore if Metadata.key has an invalid format (not a map). +// Reflect is used to be able to access maps defined with different types. 
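+// Duplicate elements under the same key are skipped: a value is only appended if no stored element is reflect.DeepEqual to it. 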
+// Element aggregation data type should be map[string]interface{} to be able to be encoded with JSON +func mergeMetadata(node *graph.Node, key string, nodeMerge map[string][]interface{}) map[string][]interface{} { + if nodeMerge == nil { + nodeMerge = map[string][]interface{}{} + } + + n1MergeIface, n1Err := node.GetField(key) + + if n1Err == nil { + // Ignore Metadata.key values which are not a map + n1MergeValue := reflect.ValueOf(n1MergeIface) + + // If the metadata value is a pointer, resolve it + if n1MergeValue.Kind() == reflect.Ptr { + n1MergeValue = n1MergeValue.Elem() + } + + // Merge step only accepts a map as data origin + if n1MergeValue.Kind() != reflect.Map { + logging.GetLogger().Errorf("Invalid type for elements, expecting a map, but it is %v", n1MergeValue.Kind()) + return nodeMerge + } + + iter := n1MergeValue.MapRange() + NODE_MERGE: + for iter.Next() { + k := fmt.Sprintf("%v", iter.Key().Interface()) + v := iter.Value().Interface() + + // Do not append if the same element already exists + for _, storedElement := range nodeMerge[k] { + if reflect.DeepEqual(storedElement, v) { + continue NODE_MERGE + } + } + + nodeMerge[k] = append(nodeMerge[k], v) + } + } + + return nodeMerge +} diff --git a/gremlin/traversal/merge_test.go b/gremlin/traversal/merge_test.go new file mode 100644 index 0000000000..cc8b25d2c8 --- /dev/null +++ b/gremlin/traversal/merge_test.go @@ -0,0 +1,610 @@ +package traversal + +import ( + "testing" + "time" + + "github.com/skydive-project/skydive/graffiti/graph" + "github.com/skydive-project/skydive/graffiti/graph/traversal" + "github.com/stretchr/testify/assert" +) + +// FakeMergeGraphBackend simulate a backend with history that could store different revisions of nodes +type FakeMergeGraphBackend struct { + graph.MemoryBackend + Nodes []*graph.Node +} + +func (b *FakeMergeGraphBackend) IsHistorySupported() bool { + return true +} + +func (b *FakeMergeGraphBackend) GetNode(i graph.Identifier, at graph.Context) []*graph.Node { + nodes := []*graph.Node{} + for _, n := range b.Nodes { + if n.ID == i { + nodes = append(nodes, n) + } + } + return nodes +} + +func (b *FakeMergeGraphBackend) GetNodesFromIDs(identifierList []graph.Identifier, at graph.Context) []*graph.Node { + nodes := []*graph.Node{} + for _, n := range b.Nodes { + for _, i := range identifierList { + if n.ID == i { + nodes = append(nodes, n) + } + } + } + return nodes +} + +func TestMergeMetadataNilNodeMerge(t *testing.T) { + key := "Merge" + + metadataNode1 := graph.Metadata{key: map[string]interface{}{ + "abc": map[interface{}]string{"descr": "foo"}, + }} + node := CreateNode("nodeA", metadataNode1, graph.TimeUTC(), 1) + + nodeMergeAgg := mergeMetadata(node, key, nil) + + expected := map[string][]interface{}{ + "abc": { + map[interface{}]string{"descr": "foo"}, + }, + } + + assert.Equal(t, expected, nodeMergeAgg) +} + +func TestMergeMetadataPointerValue(t *testing.T) { + key := "Merge" + + value := map[string]interface{}{ + "abc": map[interface{}]string{"descr": "foo"}, + } + + metadataNode1 := graph.Metadata{key: &value} + node := CreateNode("nodeA", metadataNode1, graph.TimeUTC(), 1) + + nodeMergeAgg := mergeMetadata(node, key, nil) + + expected := map[string][]interface{}{ + "abc": { + map[interface{}]string{"descr": "foo"}, + }, + } + + assert.Equal(t, expected, nodeMergeAgg) +} + +func TestMergeMetadata(t *testing.T) { + tests := []struct { + name string + nodesMerge []interface{} + expected map[string][]interface{} + }{ + { + name: "no nodes", + expected: 
map[string][]interface{}{}, + }, + { + name: "one node", + nodesMerge: []interface{}{ + map[string]interface{}{ + "abc": map[string]string{"descr": "foo"}, + }}, + expected: map[string][]interface{}{ + "abc": { + map[string]string{"descr": "foo"}, + }, + }, + }, + { + name: "two nodes, different keys", + nodesMerge: []interface{}{ + map[string]interface{}{ + "abc": map[string]string{"descr": "foo"}, + }, + map[string]interface{}{ + "xyz": map[string]string{"descr": "bar"}, + }}, + expected: map[string][]interface{}{ + "abc": { + map[string]string{"descr": "foo"}, + }, + "xyz": { + map[string]string{"descr": "bar"}, + }, + }, + }, + { + name: "two nodes, same keys", + nodesMerge: []interface{}{ + map[string]interface{}{ + "abc": map[string]string{"descr": "foo"}, + }, + map[string]interface{}{ + "abc": map[string]string{"descr": "bar"}, + }}, + expected: map[string][]interface{}{ + "abc": { + map[string]string{"descr": "foo"}, + map[string]string{"descr": "bar"}, + }, + }, + }, + { + name: "two nodes, repeating one event, should be removed", + nodesMerge: []interface{}{ + map[string]interface{}{ + "abc": map[string]string{"descr": "foo"}, + }, + map[string]interface{}{ + "abc": map[string]string{"descr": "foo"}, + "xxx": map[string]string{"descr": "bar"}, + }}, + expected: map[string][]interface{}{ + "abc": { + map[string]string{"descr": "foo"}, + }, + "xxx": { + map[string]string{"descr": "bar"}, + }, + }, + }, + } + + key := "Merge" + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + nodeMergeAgg := map[string][]interface{}{} + + for _, nodeMerge := range test.nodesMerge { + metadataNode1 := graph.Metadata{key: nodeMerge} + node := CreateNode("nodeA", metadataNode1, graph.TimeUTC(), 1) + + nodeMergeAgg = mergeMetadata(node, key, nodeMergeAgg) + } + + assert.Equal(t, test.expected, nodeMergeAgg) + }) + } +} + +func TestInterfaceMerge(t *testing.T) { + tests := []struct { + name string + InNodes []*graph.Node + key string + aggKey string + startTime time.Time + endTime time.Time + // Expected nodes + OutNodes []*graph.Node + }{ + { + name: "no input nodes", + }, + { + // Node passes the step without being modified + name: "one input node without key defined", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{}, graph.Time(time.Unix(0, 0)), 1), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "MergeAgg": map[string][]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 1), + }, + }, + { + name: "one input node with key defined but empty", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 1), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + "MergeAgg": map[string][]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 1), + }, + }, + { + name: "one input node with key defined and one element", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(0, 0)), 1), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + "MergeAgg": map[string][]interface{}{"e1": { + map[string]interface{}{"desc": "a"}, + }}, + }, graph.Time(time.Unix(0, 0)), 1), + }, + }, + { + name: "one input node with a complex element", + key: 
"Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{ + "e1": map[string]interface{}{ + "desc": "a", + "TTL": 45, + "Payload": []interface{}{ + map[string]interface{}{"Key": "foo"}, + map[string]interface{}{"Value": "bar"}, + }, + }, + }, + }, graph.Time(time.Unix(0, 0)), 1), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{ + "e1": map[string]interface{}{ + "desc": "a", + "TTL": 45, + "Payload": []interface{}{ + map[string]interface{}{"Key": "foo"}, + map[string]interface{}{"Value": "bar"}, + }, + }, + }, + "MergeAgg": map[string][]interface{}{ + "e1": { + map[string]interface{}{ + "desc": "a", + "TTL": 45, + "Payload": []interface{}{ + map[string]interface{}{"Key": "foo"}, + map[string]interface{}{"Value": "bar"}, + }, + }, + }, + }, + }, graph.Time(time.Unix(0, 0)), 1), + }, + }, + { + name: "two different input nodes with key defined and one element each one", + key: "Merge", + aggKey: "MergeAxx", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(0, 0)), 1), + CreateNode("B", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(0, 0)), 1), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + "MergeAxx": map[string][]interface{}{"e1": { + map[string]interface{}{"desc": "a"}, + }}, + }, graph.Time(time.Unix(0, 0)), 1), + CreateNode("B", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + "MergeAxx": map[string][]interface{}{"e1": { + map[string]interface{}{"desc": "a"}, + }}, + }, graph.Time(time.Unix(0, 0)), 1), + }, + }, + { + name: "one node, with a previous version, both without key defined", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{}, graph.Time(time.Unix(0, 0)), 1), + CreateNode("A", graph.Metadata{}, graph.Time(time.Unix(0, 0)), 2), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "MergeAgg": map[string][]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + }, + { + name: "one node, with a previous version, both with key defined but empty", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 1), + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + "MergeAgg": map[string][]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + }, + { + name: "one node, with a previous version, both with key defined, same event different content", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(0, 0)), 1), + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "b"}}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "b"}}, + "MergeAgg": map[string][]interface{}{"e1": { + 
map[string]interface{}{"desc": "a"}, + map[string]interface{}{"desc": "b"}, + }}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + }, + { + name: "one node, with a previous version, first one without event, second one with event", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 1), + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "b"}}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "b"}}, + "MergeAgg": map[string][]interface{}{"e1": { + map[string]interface{}{"desc": "b"}, + }}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + }, + { + name: "one node, with a previous version, first one with event, second one without event", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(0, 0)), 1), + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + "MergeAgg": map[string][]interface{}{"e1": { + map[string]interface{}{"desc": "a"}, + }}, + }, graph.Time(time.Unix(0, 0)), 2), + }, + }, + { + name: "one node, with two previous versions, first with, second without, third with", + key: "Merge", + aggKey: "MergeAgg", + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(0, 0)), 1), + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{}, + }, graph.Time(time.Unix(0, 0)), 2), + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "c"}}, + }, graph.Time(time.Unix(0, 0)), 3), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "c"}}, + "MergeAgg": map[string][]interface{}{"e1": { + map[string]interface{}{"desc": "a"}, + map[string]interface{}{"desc": "c"}, + }}, + }, graph.Time(time.Unix(0, 0)), 3), + }, + }, + { + name: "memory backend does not filter nodes by date, startTime and endTime are ignored", + key: "Merge", + aggKey: "MergeAgg", + startTime: time.Unix(100, 0), + endTime: time.Unix(200, 0), + InNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(300, 0)), 1), + }, + OutNodes: []*graph.Node{ + CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + "MergeAgg": map[string][]interface{}{"e1": { + map[string]interface{}{"desc": "a"}, + }}, + }, graph.Time(time.Unix(300, 0)), 1), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + b := FakeMergeGraphBackend{ + Nodes: test.InNodes, + } + g := graph.NewGraph("testhost", &b, "analyzer.testhost") + + gt := traversal.NewGraphTraversal(g, false) + tvIn := traversal.NewGraphTraversalV(gt, test.InNodes) + + s := MergeGremlinTraversalStep{ + MergeKey: test.key, + MergeAggKey: test.aggKey, + StartTime: test.startTime, + EndTime: test.endTime, + } + ts, err := s.InterfaceMerge(tvIn) + if err != nil { + 
t.Error(err.Error()) + } + + tvOut, ok := ts.(*traversal.GraphTraversalV) + if !ok { + t.Errorf("Invalid GraphTraversal type") + } + + assert.ElementsMatch(t, test.OutNodes, tvOut.GetNodes()) + }) + } +} + +func TestInterfaceMergeDoNotModifyOriginNodes(t *testing.T) { + n := CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(0, 0)), 1) + + nCopy := CreateNode("A", graph.Metadata{ + "Merge": map[string]interface{}{"e1": map[string]interface{}{"desc": "a"}}, + }, graph.Time(time.Unix(0, 0)), 1) + + b := FakeMergeGraphBackend{ + Nodes: []*graph.Node{n}, + } + g := graph.NewGraph("testhost", &b, "analyzer.testhost") + + gt := traversal.NewGraphTraversal(g, false) + tvIn := traversal.NewGraphTraversalV(gt, []*graph.Node{n}) + + s := MergeGremlinTraversalStep{ + MergeKey: "Merge", + MergeAggKey: "AggMerge", + } + _, err := s.InterfaceMerge(tvIn) + assert.Nil(t, err) + + // Node stored in the graph should not be modified + assert.Equal(t, b.GetNode("A", graph.Context{})[0], nCopy) +} + +func TestEventsParseStep(t *testing.T) { + tests := []struct { + name string + token traversal.Token + traversalCtx traversal.GremlinTraversalContext + expectedTraversalStep traversal.GremlinTraversalStep + expectedError string + }{ + { + name: "non merge token", + token: traversal.COUNT, + }, + { + name: "nil traversalCtx", + token: traversalMergeToken, + expectedError: "Merge parameter must have two or four parameters (OriginKey, DestinationKey, StartTime, EndTime)", + }, + { + name: "only one param", + token: traversalMergeToken, + traversalCtx: traversal.GremlinTraversalContext{ + Params: []interface{}{"foo"}, + }, + expectedError: "Merge parameter must have two or four parameters (OriginKey, DestinationKey, StartTime, EndTime)", + }, + { + name: "two param not string", + token: traversalMergeToken, + traversalCtx: traversal.GremlinTraversalContext{ + Params: []interface{}{1, 2}, + }, + expectedError: "Merge first parameter have to be a string", + }, + { + name: "two string params", + token: traversalMergeToken, + traversalCtx: traversal.GremlinTraversalContext{ + Params: []interface{}{"key", "aggKey"}, + }, + expectedTraversalStep: &MergeGremlinTraversalStep{ + GremlinTraversalContext: traversal.GremlinTraversalContext{ + Params: []interface{}{"key", "aggKey"}, + }, + MergeKey: "key", + MergeAggKey: "aggKey", + }, + }, + { + name: "four valid params", + token: traversalMergeToken, + traversalCtx: traversal.GremlinTraversalContext{ + Params: []interface{}{"key", "aggKey", int64(1627987976), int64(1627987977)}, + }, + expectedTraversalStep: &MergeGremlinTraversalStep{ + GremlinTraversalContext: traversal.GremlinTraversalContext{ + Params: []interface{}{"key", "aggKey", int64(1627987976), int64(1627987977)}, + }, + MergeKey: "key", + MergeAggKey: "aggKey", + StartTime: time.Unix(1627987976, 0), + EndTime: time.Unix(1627987977, 0), + }, + }, + { + name: "invalid start date", + token: traversalMergeToken, + traversalCtx: traversal.GremlinTraversalContext{ + Params: []interface{}{"foo", "bar", "123456789", "123456789"}, + }, + expectedError: "Merge third parameter have to be a int (unix epoch time)", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := MergeTraversalExtension{MergeToken: traversalMergeToken} + + traversalStep, err := e.ParseStep(test.token, test.traversalCtx) + if test.expectedError != "" { + assert.EqualErrorf(t, err, test.expectedError, "error") + } else { + assert.Nil(t, err, 
"nil error") + } + + assert.Equalf(t, test.expectedTraversalStep, traversalStep, "step") + }) + } +} + +// CreateNode func to create nodes with a specific node revision +func CreateNode(id string, m graph.Metadata, t graph.Time, revision int64) *graph.Node { + n := graph.CreateNode(graph.Identifier(id), m, t, "host", "orig") + n.Revision = revision + return n +} diff --git a/gremlin/traversal/neighbors.go b/gremlin/traversal/neighbors.go new file mode 100644 index 0000000000..ec3eee009c --- /dev/null +++ b/gremlin/traversal/neighbors.go @@ -0,0 +1,192 @@ +package traversal + +import ( + "github.com/pkg/errors" + + "github.com/skydive-project/skydive/graffiti/filters" + "github.com/skydive-project/skydive/graffiti/graph" + "github.com/skydive-project/skydive/graffiti/graph/traversal" + "github.com/skydive-project/skydive/topology" +) + +// NeighborsTraversalExtension describes a new extension to enhance the topology +type NeighborsTraversalExtension struct { + NeighborsToken traversal.Token +} + +// NeighborsGremlinTraversalStep navigate the graph starting from a node, following edges +// from parent to child and from child to parent. +// It follows the same sintaxis as Ascendants and Descendants step. +// The behaviour is like Ascendants+Descendants combined. +// If only one param is defined, it is used as depth, eg: G.V('id').Neighbors(4) +// If we have an event number of parameters, they are used as edge filter, and +// depth is defaulted to one, eg.: G.V('id').Neighbors("Type","foo","RelationType","bar") +// If we have an odd, but >1, number of parameters, all but the last one are used as +// edge filters and the last one as depth, eg.: G.V('id').Neighbors("Type","foo","RelationType","bar",3) +type NeighborsGremlinTraversalStep struct { + context traversal.GremlinTraversalContext + maxDepth int64 + edgeFilter graph.ElementMatcher + // nextStepOnlyIDs is set to true if the next step only needs node IDs and not the whole node info + nextStepOnlyIDs bool +} + +// NewNeighborsTraversalExtension returns a new graph traversal extension +func NewNeighborsTraversalExtension() *NeighborsTraversalExtension { + return &NeighborsTraversalExtension{ + NeighborsToken: traversalNeighborsToken, + } +} + +// ScanIdent returns an associated graph token +func (e *NeighborsTraversalExtension) ScanIdent(s string) (traversal.Token, bool) { + switch s { + case "NEIGHBORS": + return e.NeighborsToken, true + } + return traversal.IDENT, false +} + +// ParseStep parses neighbors step +func (e *NeighborsTraversalExtension) ParseStep(t traversal.Token, p traversal.GremlinTraversalContext) (traversal.GremlinTraversalStep, error) { + switch t { + case e.NeighborsToken: + default: + return nil, nil + } + + maxDepth := int64(1) + edgeFilter, _ := topology.OwnershipMetadata().Filter() + + switch len(p.Params) { + case 0: + default: + i := len(p.Params) / 2 * 2 + filter, err := traversal.ParamsToFilter(filters.BoolFilterOp_OR, p.Params[:i]...) 
+ if err != nil { + return nil, errors.Wrap(err, "Neighbors accepts an optional number of key/value tuples and an optional depth") + } + edgeFilter = filter + + if i == len(p.Params) { + break + } + + fallthrough + case 1: + depth, ok := p.Params[len(p.Params)-1].(int64) + if !ok { + return nil, errors.New("Neighbors last argument must be the maximum depth specified as an integer") + } + maxDepth = depth + } + + return &NeighborsGremlinTraversalStep{context: p, maxDepth: maxDepth, edgeFilter: graph.NewElementFilter(edgeFilter)}, nil +} + +// getNeighbors given a list of nodes, get its neighbors nodes for "maxDepth" depth relationships. +// Edges between nodes must fulfill "edgeFilter" filter. +// Nodes passed to this function will always be in the response. +func (d *NeighborsGremlinTraversalStep) getNeighbors(g *graph.Graph, nodes []*graph.Node) []*graph.Node { + // visitedNodes store neighors and avoid visiting twice the same node + visitedNodes := map[graph.Identifier]interface{}{} + + // currentDepthNodesIDs slice with the nodes being processed in each depth. + // We use "empty" while procesing the neighbors nodes to avoid extra calls to the backend. + var currentDepthNodesIDs []graph.Identifier + // nextDepthNodes slice were next depth nodes are being stored. + // Initializated with the list of origin nodes where it should start from. + nextDepthNodesIDs := make([]graph.Identifier, 0, len(nodes)) + + // Mark origin nodes as already visited + // Neighbor step will return also the origin nodes + for _, n := range nodes { + visitedNodes[n.ID] = struct{}{} + nextDepthNodesIDs = append(nextDepthNodesIDs, n.ID) + } + + // DFS + // BFS must not be used because could lead to ignore some servers in this case: + // A -> B + // B -> C + // C -> D + // A -> C + // With depth=2, BFS will return A,B,C (C is visited in A->B->C, si ignored in A->C->D) + // DFS will return, the correct, A,B,C,D + for i := 0; i < int(d.maxDepth); i++ { + // Copy values from nextDepthNodes to currentDepthNodes + currentDepthNodesIDs = make([]graph.Identifier, len(nextDepthNodesIDs)) + copy(currentDepthNodesIDs, nextDepthNodesIDs) + + nextDepthNodesIDs = nextDepthNodesIDs[:0] // Clean slice, keeping capacity + // Get all edges for the list of nodes, filtered by edgeFilter + // Convert the list of node ids to a list of nodes + + currentDepthNodes := make([]*graph.Node, 0, len(currentDepthNodesIDs)) + for _, nID := range currentDepthNodesIDs { + currentDepthNodes = append(currentDepthNodes, graph.CreateNode(nID, graph.Metadata{}, graph.Unix(0, 0), "", "")) + } + edges := g.GetNodesEdges(currentDepthNodes, d.edgeFilter) + + for _, e := range edges { + // Get nodeID of the other side of the edge + // Store neighbors + // We don't know in which side of the edge are the neighbors, so, add both sides if not already visited + _, okParent := visitedNodes[e.Parent] + if !okParent { + visitedNodes[e.Parent] = struct{}{} + // Do not walk nodes already processed + nextDepthNodesIDs = append(nextDepthNodesIDs, e.Parent) + } + _, okChild := visitedNodes[e.Child] + if !okChild { + visitedNodes[e.Child] = struct{}{} + nextDepthNodesIDs = append(nextDepthNodesIDs, e.Child) + } + } + } + + // Return "empty" nodes (just with the ID) if the next step only need that info + if d.nextStepOnlyIDs { + ret := make([]*graph.Node, 0, len(visitedNodes)) + for nID := range visitedNodes { + ret = append(ret, graph.CreateNode(nID, graph.Metadata{}, graph.Unix(0, 0), "", "")) + } + return ret + } + + // Get concurrentl all nodes for the list of neighbors 
ids + nodesIDs := make([]graph.Identifier, 0, len(visitedNodes)) + for n := range visitedNodes { + nodesIDs = append(nodesIDs, n) + } + + return g.GetNodesFromIDs(nodesIDs) +} + +// Exec Neighbors step +func (d *NeighborsGremlinTraversalStep) Exec(last traversal.GraphTraversalStep) (traversal.GraphTraversalStep, error) { + switch tv := last.(type) { + case *traversal.GraphTraversalV: + tv.GraphTraversal.RLock() + neighbors := d.getNeighbors(tv.GraphTraversal.Graph, tv.GetNodes()) + tv.GraphTraversal.RUnlock() + + return traversal.NewGraphTraversalV(tv.GraphTraversal, neighbors), nil + } + return nil, traversal.ErrExecutionError +} + +// Reduce Neighbors step +func (d *NeighborsGremlinTraversalStep) Reduce(next traversal.GremlinTraversalStep) (traversal.GremlinTraversalStep, error) { + // Merge step only needs the ids of nodes. Saving some queries. + if _, ok := next.(*MergeGremlinTraversalStep); ok { + d.nextStepOnlyIDs = true + } + return next, nil +} + +// Context Neighbors step +func (d *NeighborsGremlinTraversalStep) Context() *traversal.GremlinTraversalContext { + return &d.context +} diff --git a/gremlin/traversal/neighbors_test.go b/gremlin/traversal/neighbors_test.go new file mode 100644 index 0000000000..72e7f2f2b5 --- /dev/null +++ b/gremlin/traversal/neighbors_test.go @@ -0,0 +1,489 @@ +package traversal + +import ( + "fmt" + "testing" + "time" + + "github.com/skydive-project/skydive/graffiti/filters" + "github.com/skydive-project/skydive/graffiti/graph" + "github.com/skydive-project/skydive/graffiti/graph/traversal" + "github.com/skydive-project/skydive/topology" + "github.com/stretchr/testify/assert" +) + +// FakeNeighborsSlowGraphBackend simulate a backend with history that could store different revisions of nodes +type FakeNeighborsSlowGraphBackend struct { + Backend *graph.MemoryBackend +} + +func (f *FakeNeighborsSlowGraphBackend) NodeAdded(n *graph.Node) error { + return f.Backend.NodeAdded(n) +} + +func (f *FakeNeighborsSlowGraphBackend) NodeDeleted(n *graph.Node) error { + return f.Backend.NodeDeleted(n) +} + +func (f *FakeNeighborsSlowGraphBackend) GetNode(i graph.Identifier, at graph.Context) []*graph.Node { + time.Sleep(20 * time.Millisecond) + return f.Backend.GetNode(i, at) +} + +func (f *FakeNeighborsSlowGraphBackend) GetNodesFromIDs(i []graph.Identifier, at graph.Context) []*graph.Node { + time.Sleep(40 * time.Millisecond) + return f.Backend.GetNodesFromIDs(i, at) +} + +func (f *FakeNeighborsSlowGraphBackend) GetNodeEdges(n *graph.Node, at graph.Context, m graph.ElementMatcher) []*graph.Edge { + time.Sleep(20 * time.Millisecond) + return f.Backend.GetNodeEdges(n, at, m) +} + +func (f *FakeNeighborsSlowGraphBackend) GetNodesEdges(n []*graph.Node, at graph.Context, m graph.ElementMatcher) []*graph.Edge { + time.Sleep(40 * time.Millisecond) + return f.Backend.GetNodesEdges(n, at, m) +} + +func (f *FakeNeighborsSlowGraphBackend) EdgeAdded(e *graph.Edge) error { + return f.Backend.EdgeAdded(e) +} + +func (f *FakeNeighborsSlowGraphBackend) EdgeDeleted(e *graph.Edge) error { + return f.Backend.EdgeDeleted(e) +} + +func (f *FakeNeighborsSlowGraphBackend) GetEdge(i graph.Identifier, at graph.Context) []*graph.Edge { + return f.Backend.GetEdge(i, at) +} + +func (f *FakeNeighborsSlowGraphBackend) GetEdgeNodes(e *graph.Edge, at graph.Context, parentMetadata graph.ElementMatcher, childMetadata graph.ElementMatcher) ([]*graph.Node, []*graph.Node) { + return f.Backend.GetEdgeNodes(e, at, parentMetadata, childMetadata) +} + +func (f 
*FakeNeighborsSlowGraphBackend) MetadataUpdated(e interface{}) error { + return f.Backend.MetadataUpdated(e) +} + +func (f *FakeNeighborsSlowGraphBackend) GetNodes(t graph.Context, e graph.ElementMatcher) []*graph.Node { + return f.Backend.GetNodes(t, e) +} + +func (f *FakeNeighborsSlowGraphBackend) GetEdges(t graph.Context, e graph.ElementMatcher) []*graph.Edge { + return f.Backend.GetEdges(t, e) +} + +func (f *FakeNeighborsSlowGraphBackend) IsHistorySupported() bool { + return f.Backend.IsHistorySupported() +} + +func TestGetNeighbors(t *testing.T) { + testCases := []struct { + desc string + graphNodes []*graph.Node + graphEdges []*graph.Edge + originNodes []*graph.Node + maxDepth int64 + edgeFilter graph.ElementMatcher + onlyIDs bool + expectedNodes []*graph.Node + }{ + { + desc: "one graph node", + graphNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("A"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + graphEdges: []*graph.Edge{}, + originNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("A"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + maxDepth: 0, + edgeFilter: nil, + expectedNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("A"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + }, + { + desc: "one graph node with only ids strip all node data except id", + graphNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("A"), graph.Metadata{"foo": "bar"}, graph.Unix(100, 0), "host", "origin"), + }, + graphEdges: []*graph.Edge{}, + originNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("A"), graph.Metadata{"foo": "bar"}, graph.Unix(100, 0), "host", "origin"), + }, + maxDepth: 0, + edgeFilter: nil, + onlyIDs: true, + expectedNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("A"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + }, + { + desc: "interface connected to host and to other interface", + graphNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + graphEdges: []*graph.Edge{ + graph.CreateEdge( + graph.Identifier("HostA-IntA"), + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ownership"}, + graph.Unix(0, 0), + "", + "", + ), + graph.CreateEdge( + graph.Identifier("IntA-IntB"), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ConnectsTo"}, + graph.Unix(0, 0), + "", + "", + ), + }, + originNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + maxDepth: 1, + edgeFilter: nil, + expectedNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + }, + { + desc: "host connected to interface and that to other interface, depth 2", + graphNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + 
graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + graphEdges: []*graph.Edge{ + graph.CreateEdge( + graph.Identifier("HostA-IntA"), + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ownership"}, + graph.Unix(0, 0), + "", + "", + ), + graph.CreateEdge( + graph.Identifier("IntA-IntB"), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ConnectsTo"}, + graph.Unix(0, 0), + "", + "", + ), + }, + originNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + maxDepth: 2, + edgeFilter: nil, + expectedNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + }, + { + desc: "two hosts connected through interfaces, depth 3", + graphNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("HostB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + graphEdges: []*graph.Edge{ + graph.CreateEdge( + graph.Identifier("HostA-IntA"), + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ownership"}, + graph.Unix(0, 0), + "", + "", + ), + graph.CreateEdge( + graph.Identifier("HostB-IntB"), + graph.CreateNode(graph.Identifier("HostB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ownership"}, + graph.Unix(0, 0), + "", + "", + ), + graph.CreateEdge( + graph.Identifier("IntA-IntB"), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ConnectsTo"}, + graph.Unix(0, 0), + "", + "", + ), + }, + originNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + maxDepth: 3, + edgeFilter: nil, + expectedNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("HostB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + }, { + desc: "two hosts connected through interfaces, reverse connection, depth 3", + graphNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + 
graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("HostB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + graphEdges: []*graph.Edge{ + graph.CreateEdge( + graph.Identifier("HostA-IntA"), + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ownership"}, + graph.Unix(0, 0), + "", + "", + ), + graph.CreateEdge( + graph.Identifier("HostB-IntB"), + graph.CreateNode(graph.Identifier("HostB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ownership"}, + graph.Unix(0, 0), + "", + "", + ), + graph.CreateEdge( + graph.Identifier("IntB-IntA"), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.Metadata{"RelationType": "ConnectsTo"}, + graph.Unix(0, 0), + "", + "", + ), + }, + originNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + maxDepth: 3, + edgeFilter: nil, + expectedNodes: []*graph.Node{ + graph.CreateNode(graph.Identifier("HostA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntA"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("HostB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + graph.CreateNode(graph.Identifier("IntB"), graph.Metadata{}, graph.Unix(0, 0), "", ""), + }, + }, + } + + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + b, err := graph.NewMemoryBackend() + if err != nil { + t.Error(err.Error()) + } + g := graph.NewGraph("testhost", b, "analyzer.testhost") + + for _, n := range tC.graphNodes { + err := g.AddNode(n) + if err != nil { + t.Error(err.Error()) + } + } + + for _, e := range tC.graphEdges { + err := g.AddEdge(e) + if err != nil { + t.Error(err.Error()) + } + } + + d := NeighborsGremlinTraversalStep{ + maxDepth: tC.maxDepth, + edgeFilter: tC.edgeFilter, + nextStepOnlyIDs: tC.onlyIDs, + } + neighbors := d.getNeighbors(g, tC.originNodes) + + assert.ElementsMatch(t, neighbors, tC.expectedNodes) + + }) + } +} + +func TestNeighborsParseStep(t *testing.T) { + ownershipFilter, err := topology.OwnershipMetadata().Filter() + assert.Nil(t, err) + + relationTypeFooFilter, err := traversal.ParamsToFilter(filters.BoolFilterOp_OR, "RelationType", "foo") + assert.Nil(t, err) + + relationTypeFooTypeBarFilter, err := traversal.ParamsToFilter(filters.BoolFilterOp_OR, "RelationType", "foo", "Type", "bar") + assert.Nil(t, err) + + tests := []struct { + name string + token traversal.Token + traversalCtx traversal.GremlinTraversalContext + expectedTraversalStep traversal.GremlinTraversalStep + expectedError string + }{ + { + name: "non merge token", + token: traversal.COUNT, + }, + { + name: "nil traversalCtx is default values, depth one and ownership edge filter", + token: traversalNeighborsToken, + expectedTraversalStep: &NeighborsGremlinTraversalStep{ + maxDepth: 1, + edgeFilter: graph.NewElementFilter(ownershipFilter), + }, + }, + { + name: "one string param", + token: traversalNeighborsToken, + traversalCtx: traversal.GremlinTraversalContext{ + 
Params: []interface{}{"foo"}, + }, + expectedError: "Neighbors last argument must be the maximum depth specified as an integer", + }, + { + name: "only one param, int number, is depth", + token: traversalNeighborsToken, + traversalCtx: traversal.GremlinTraversalContext{ + Params: []interface{}{int64(3)}, + }, + expectedTraversalStep: &NeighborsGremlinTraversalStep{ + context: traversal.GremlinTraversalContext{ + Params: []interface{}{int64(3)}, + }, + maxDepth: 3, + edgeFilter: graph.NewElementFilter(ownershipFilter), + }, + }, + { + name: "two string params are used as edge filter", + token: traversalNeighborsToken, + traversalCtx: traversal.GremlinTraversalContext{ + Params: []interface{}{"RelationType", "foo"}, + }, + expectedTraversalStep: &NeighborsGremlinTraversalStep{ + context: traversal.GremlinTraversalContext{ + Params: []interface{}{"RelationType", "foo"}, + }, + maxDepth: 1, + edgeFilter: graph.NewElementFilter(relationTypeFooFilter), + }, + }, + { + name: "four string params are used as edge filter and last int64 as depth", + token: traversalNeighborsToken, + traversalCtx: traversal.GremlinTraversalContext{ + Params: []interface{}{"RelationType", "foo", "Type", "bar", int64(5)}, + }, + expectedTraversalStep: &NeighborsGremlinTraversalStep{ + context: traversal.GremlinTraversalContext{ + Params: []interface{}{"RelationType", "foo", "Type", "bar", int64(5)}, + }, + maxDepth: 5, + edgeFilter: graph.NewElementFilter(relationTypeFooTypeBarFilter), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := NeighborsTraversalExtension{NeighborsToken: traversalNeighborsToken} + + traversalStep, err := e.ParseStep(test.token, test.traversalCtx) + if test.expectedError != "" { + assert.EqualErrorf(t, err, test.expectedError, "error") + } else { + assert.Nil(t, err, "nil error") + } + + assert.Equalf(t, test.expectedTraversalStep, traversalStep, "step") + }) + } +} + +func BenchmarkGetNeighbors(b *testing.B) { + // Create graph with nodes and edges + backend, err := graph.NewMemoryBackend() + if err != nil { + b.Error(err.Error()) + } + + slowBackend := FakeNeighborsSlowGraphBackend{backend} + g := graph.NewGraph("testhost", &slowBackend, "analyzer.testhost") + + parentNodes := 20 + + var node *graph.Node + var nodeChild *graph.Node + for n := 0; n < parentNodes; n++ { + node, err = g.NewNode(graph.Identifier(fmt.Sprintf("%d", n)), graph.Metadata{}) + if err != nil { + b.Error(err.Error()) + } + + // Childs of this node + for nc := 0; nc < 60; nc++ { + nodeChild, err = g.NewNode(graph.Identifier(fmt.Sprintf("%d-%d", n, nc)), graph.Metadata{}) + if err != nil { + b.Error(err.Error()) + } + + _, err = g.NewEdge(graph.Identifier(fmt.Sprintf("%d-%d", n, nc)), node, nodeChild, graph.Metadata{}) + if err != nil { + b.Error(err.Error()) + } + } + } + + // Each node connects with its next + nextNodeConnect := 5 + for n := 0; n < parentNodes-nextNodeConnect; n++ { + for p := 1; p < nextNodeConnect; p++ { + // Connect interfaces + ifaceParentNode := g.GetNode(graph.Identifier(fmt.Sprintf("%d-%d", n, p))) + ifaceChildNode := g.GetNode(graph.Identifier(fmt.Sprintf("%d-%d", n+p, n))) + + _, err = g.NewEdge(graph.Identifier(fmt.Sprintf("c-%d-%d", n, p)), ifaceParentNode, ifaceChildNode, graph.Metadata{}) + if err != nil { + b.Error(err.Error()) + } + } + + } + + // Using depth=8 we get a total of 798 neighbors + d := NeighborsGremlinTraversalStep{ + maxDepth: 8, + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + d.getNeighbors(g, 
[]*graph.Node{g.GetNode(graph.Identifier("1"))}) + } +} diff --git a/gremlin/traversal/token.go b/gremlin/traversal/token.go index e6604e9c89..3995efd185 100644 --- a/gremlin/traversal/token.go +++ b/gremlin/traversal/token.go @@ -34,4 +34,6 @@ const ( traversalGroupToken traversal.Token = 1012 traversalMoreThanToken traversal.Token = 1013 traversalAscendantsToken traversal.Token = 1014 + traversalNeighborsToken traversal.Token = 1015 + traversalMergeToken traversal.Token = 1016 ) diff --git a/validator/validator.go b/validator/validator.go index f721939b64..5d041794d3 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -177,8 +177,10 @@ func isGremlinExpr(v interface{}, param string) error { tr.AddTraversalExtension(ge.NewRawPacketsTraversalExtension()) tr.AddTraversalExtension(ge.NewDescendantsTraversalExtension()) tr.AddTraversalExtension(ge.NewAscendantsTraversalExtension()) + tr.AddTraversalExtension(ge.NewNeighborsTraversalExtension()) tr.AddTraversalExtension(ge.NewNextHopTraversalExtension()) tr.AddTraversalExtension(ge.NewGroupTraversalExtension()) + tr.AddTraversalExtension(ge.NewMergeTraversalExtension()) if _, err := tr.Parse(strings.NewReader(query)); err != nil { return GremlinNotValid(err)