
mercury.graph.core

mercury.graph.core.Graph(data=None, keys=None, nodes=None)

This is the main class in mercury.graph.

This class abstracts the underlying technology used to represent the graph. You can create a graph by passing one of the following objects to the constructor:

  • A pandas DataFrame containing edges (with a keys dictionary to specify the columns and possibly a nodes DataFrame)
  • A pyspark DataFrame containing edges (with a keys dictionary to specify the columns and possibly a nodes DataFrame)
  • A networkx graph
  • A graphframes graph

Bear in mind that the graph object is immutable. This means that you can't modify the graph object once it has been created. If you want to modify it, you have to create a new graph object.

The graph object provides:

  • Properties to access the graph in different formats (networkx, graphframes, dgl)
  • Properties with metrics and summary information that are calculated on demand and technology independent.
  • It is inherited by other graph classes in mercury-graph providing ML algorithms such as graph embedding, visualization, etc.

Using this class from the other classes in mercury-graph:

The other classes in mercury-graph define models or functionalities that are based on graphs. They use a Scikit-learn-like API to interact with the graph object. This means that the graph object is passed to the class constructor and the class follows the Scikit-learn conventions. It is recommended to follow the same conventions when creating your own classes to work with mercury-graph.

The conventions can be found here:

Parameters:

Name Type Description Default
data (DataFrame, Graph or DataFrame)

The data to create the graph from. It can be a pandas DataFrame, a networkx Graph, a pyspark DataFrame, or a graphframes GraphFrame. If it already contains a graph (networkx or graphframes), the keys and nodes arguments are ignored.

None
keys dict

A dictionary with keys to specify the columns in the data DataFrame. The keys are:

  • 'src': The name of the column with the source node.
  • 'dst': The name of the column with the destination node.
  • 'id': The name of the column with the node id.
  • 'weight': The name of the column with the edge weight.
  • 'directed': A boolean to specify if the graph is directed. (Used only when data is a DataFrame; the source shown on this page reads it for both pandas and pyspark input.)

When the keys argument is not provided, or one of its keys is missing, the default values are:

  • 'src': 'src'
  • 'dst': 'dst'
  • 'id': 'id'
  • 'weight': 'weight'
  • 'directed': True
None
nodes DataFrame

A pandas DataFrame or a pyspark DataFrame with the nodes data. (Used only when data is a pandas or pyspark DataFrame; nodes must be of the same type as data.) If not given, the nodes are inferred from the edges DataFrame.

None
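
The default handling described for the keys argument can be sketched in plain Python. The helper name `resolve_keys` and the constant `DEFAULT_KEYS` are hypothetical and not part of mercury-graph; they only mirror the documented defaults and the `keys.get(name, default)` pattern used by the constructor helpers.

```python
# Hypothetical helper, not part of mercury-graph: mirrors the documented defaults.
DEFAULT_KEYS = {
    'src': 'src',
    'dst': 'dst',
    'id': 'id',
    'weight': 'weight',
    'directed': True,
}


def resolve_keys(keys=None):
    """Return a complete keys dictionary, filling missing entries with the defaults."""
    resolved = dict(DEFAULT_KEYS)
    if keys is not None:
        resolved.update(keys)
    return resolved


# Only some keys are given; the rest fall back to the defaults.
partial = resolve_keys({'src': 'origin', 'dst': 'target', 'directed': False})
print(partial)
```

With this sketch, a caller only names the columns that differ from the defaults.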
Source code in mercury/graph/core/graph.py
def __init__(self, data = None, keys = None, nodes = None):
    self._as_networkx = None
    self._as_graphframe = None
    self._as_dgl = None
    self._degree = None
    self._in_degree = None
    self._out_degree = None
    self._closeness_centrality = None
    self._betweenness_centrality = None
    self._pagerank = None
    self._connected_components = None
    self._nodes_colnames = None
    self._edges_colnames = None

    self._number_of_nodes = 0
    self._number_of_edges = 0
    self._node_ix = 0
    self._is_directed = False
    self._is_weighted = False

    self._init_values = {k: v for k, v in locals().items() if k in inspect.signature(self.__init__).parameters}

    if type(data) == pd.core.frame.DataFrame:
        self._from_pandas(data, nodes, keys)
        return

    if isinstance(data, nx.Graph):      # This is the most general case, including: ...Graph, ...DiGraph and ...MultiGraph
        self._from_networkx(data)
        return

    spark_int = SparkInterface()

    if pyspark_installed and graphframes_installed:
        if type(data) == spark_int.type_spark_dataframe:
            self._from_dataframe(data, nodes, keys)
            return

        if type(data) == spark_int.type_graphframe:
            self._from_graphframes(data)
            return

    raise ValueError('Invalid input data. (Expected: pandas DataFrame, a networkx Graph, a pyspark DataFrame, a graphframes Graph.)')

betweenness_centrality property

Returns the betweenness centrality of each node in the graph as a Python dictionary.

closeness_centrality property

Returns the closeness centrality of each node in the graph as a Python dictionary.

connected_components property

Returns the connected components of each node in the graph as a Python dictionary.

degree property

Returns the degree of each node in the graph as a Python dictionary.

dgl property

Returns the graph as a DGL graph.

If the graph has not been converted to a DGL graph yet, it will be converted and cached for future use.

Returns:

Type Description
DGLGraph

The graph represented as a DGL graph.

edges property

Returns an iterator over the edges in the graph.

Returns:

Type Description
EdgeIterator

An iterator object that allows iterating over the edges in the graph.

edges_colnames property

Returns the column names of the edges DataFrame.

graphframe property

Returns the graph as a GraphFrame.

If the graph has not been converted to a GraphFrame yet, it will be converted and cached for future use.

Returns:

Type Description
GraphFrame

The graph represented as a GraphFrame.

in_degree property

Returns the in-degree of each node in the graph as a Python dictionary.

is_directed property

Returns True if the graph is directed, False otherwise.

Note

Graphs created using graphframes are always directed. The workaround is to add the reverse edges to the graph. To do this, create the Graph from a pyspark DataFrame and set the 'directed' key to False in the keys dictionary. Otherwise, the graph is considered directed, even if reverse edges were added by other means that this class cannot be aware of.
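
The idea of representing an undirected graph by adding reverse edges can be sketched on a plain list of tuples. The function name `add_reverse_edges` and the `(src, dst, weight)` tuple layout are assumptions made for illustration; the library itself works on pyspark DataFrames.

```python
def add_reverse_edges(edges):
    """Return the edges plus their reversed copies, without duplicates.

    `edges` is an iterable of (src, dst, weight) tuples (an assumed layout).
    """
    seen = set()
    result = []
    for src, dst, weight in edges:
        # Consider the edge itself and its reversed copy.
        for edge in ((src, dst, weight), (dst, src, weight)):
            if edge not in seen:
                seen.add(edge)
                result.append(edge)
    return result


directed_edges = [('a', 'b', 1.0), ('b', 'c', 2.0), ('b', 'a', 1.0)]
symmetric_edges = add_reverse_edges(directed_edges)
print(symmetric_edges)
```

An edge that already has its reverse in the input is not duplicated, which corresponds to taking the distinct union of the edges and their reversed copies.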

is_weighted property

Returns True if the graph is weighted, False otherwise.

A graph is considered weighted if its edges DataFrame has a column named 'weight', or if the column has a different name and that name is passed as the 'weight' key of the keys dictionary.

networkx property

Returns the graph representation as a NetworkX graph.

If the graph has not been converted to NetworkX format yet, it will be converted and cached for future use.

Returns:

Type Description
Graph

The graph representation as a NetworkX graph.
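
The convert-on-first-access behavior described for the format properties can be illustrated with a small caching property. The class `LazyGraphView` and its `adjacency` property are hypothetical stand-ins, not the mercury-graph implementation; the sketch only assumes that the conversion result is stored in a private attribute that starts as None.

```python
class LazyGraphView:
    """Hypothetical stand-in showing convert-once-then-cache; not the mercury-graph class."""

    def __init__(self, edges):
        self._edges = list(edges)
        self._as_adjacency = None  # cache slot, filled on first access
        self.conversions = 0       # counts how often the conversion actually ran

    @property
    def adjacency(self):
        # Convert only if no cached representation exists yet.
        if self._as_adjacency is None:
            self._as_adjacency = self._to_adjacency()
        return self._as_adjacency

    def _to_adjacency(self):
        self.conversions += 1
        adjacency = {}
        for src, dst in self._edges:
            adjacency.setdefault(src, set()).add(dst)
            adjacency.setdefault(dst, set())
        return adjacency


view = LazyGraphView([('a', 'b'), ('b', 'c')])
first = view.adjacency   # triggers the conversion
second = view.adjacency  # served from the cache
print(view.conversions)
```

Because the graph object is immutable, a cached representation never goes stale, which is what makes this pattern safe.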

nodes property

Returns an iterator over all the nodes in the graph.

Returns:

Type Description
NodeIterator

An iterator that yields each node in the graph.

nodes_colnames property

Returns the column names of the nodes DataFrame.

number_of_edges property

Returns the number of edges in the graph.

Returns:

Type Description
int

The number of edges in the graph.

number_of_nodes property

Returns the number of nodes in the graph.

Returns:

Type Description
int

The number of nodes in the graph.

out_degree property

Returns the out-degree of each node in the graph as a Python dictionary.

pagerank property

Returns the PageRank of each node in the graph as a Python dictionary.

_calculate_betweenness_centrality()

This internal method handles the logic of a property. It returns the betweenness centrality of each node in the graph as a Python dictionary. NOTE: This method always converts the graph to a networkx graph, because betweenness centrality is too computationally expensive to be practical on large graphs.

Source code in mercury/graph/core/graph.py
def _calculate_betweenness_centrality(self):
    """
    This internal method handles the logic of a property. It returns the betweenness centrality of each node in the graph as a Python
    dictionary. NOTE: This method converts the graph to a networkx graph to calculate the betweenness centrality since the algorithm
    is too computationally expensive to use on large graphs.
    """
    return nx.betweenness_centrality(self.networkx)

_calculate_closeness_centrality()

This internal method handles the logic of a property. It returns the closeness centrality of each node in the graph as a Python dictionary.

Source code in mercury/graph/core/graph.py
def _calculate_closeness_centrality(self):
    """
    This internal method handles the logic of a property. It returns the closeness centrality of each node in the graph as
    a Python dictionary.
    """
    if self._as_networkx is not None:
        return nx.closeness_centrality(self._as_networkx)

    nodes = [row['id'] for row in self.graphframe.vertices.select('id').collect()]
    paths = self.graphframe.shortestPaths(landmarks = nodes)
    expr  = SparkInterface().pyspark.sql.functions.expr
    sums  = paths.withColumn('sums', expr('aggregate(map_values(distances), 0, (acc, x) -> acc + x)'))

    cc = sums.withColumn('cc', (self.number_of_nodes - 1)/sums['sums']).select('id', 'cc')

    return {row['id']: row['cc'] for row in cc.collect()}
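
The graphframes branch above computes, for each node, (number of nodes - 1) divided by the sum of its shortest-path distances. A minimal pure-Python sketch of that formula follows, assuming an unweighted, connected graph given as an adjacency dictionary. The function name and input format are illustrative and not part of mercury-graph.

```python
from collections import deque


def closeness_centrality(adjacency):
    """Closeness as (n - 1) / sum of hop distances, via breadth-first search.

    `adjacency` maps each node to an iterable of its neighbors (assumed format).
    """
    n = len(adjacency)
    result = {}
    for source in adjacency:
        # Breadth-first search gives hop distances from `source`.
        dist = {source: 0}
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for neighbor in adjacency[node]:
                if neighbor not in dist:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        total = sum(dist.values())
        # Isolated nodes have no reachable neighbors; report 0.0 for them.
        result[source] = (n - 1) / total if total > 0 else 0.0
    return result


# Path graph a - b - c: the middle node is closest to everyone.
path = {'a': ['b'], 'b': ['a', 'c'], 'c': ['b']}
cc = closeness_centrality(path)
print(cc)
```

On disconnected graphs the two backends may treat unreachable nodes differently, so this sketch should only be read as an illustration of the formula.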

_calculate_connected_components()

This internal method handles the logic of a property. It returns the connected components of each node in the graph as a Python dictionary.

Source code in mercury/graph/core/graph.py
def _calculate_connected_components(self):
    """
    This internal method handles the logic of a property. It returns the connected components of each node in the graph as a Python
    dictionary.
    """
    if self._as_networkx is not None:
        if self._is_directed:
            G = self._as_networkx.to_undirected()
        else:
            G = self._as_networkx

        graphs = (G.subgraph(c) for c in nx.connected_components(G))
        cc = dict()
        for i, graph in enumerate(graphs):
            n = graph.number_of_nodes()
            for nid in graph.nodes:
                cc[nid] = {'cc_id' : i, 'cc_size' : n}

        return cc

    graphs = self.graphframe.connectedComponents(algorithm = 'graphx')
    cc_size = graphs.select('id', 'component').groupBy('component').count()
    cc_all = graphs.select('id', 'component').join(cc_size, 'component', how = 'left_outer')

    cc = dict()
    for row in cc_all.collect():
        cc[row['id']] = {'cc_id' : row['component'], 'cc_size' : row['count']}

    return cc
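
The shape of the returned dictionary, one `{'cc_id': ..., 'cc_size': ...}` entry per node, can be reproduced with a small pure-Python sketch. The function below is illustrative only: it treats edges as undirected and numbers the components in discovery order, which is an assumption of this sketch.

```python
def connected_components(nodes, edges):
    """Map each node to {'cc_id': component index, 'cc_size': component size}."""
    adjacency = {node: set() for node in nodes}
    for src, dst in edges:
        # Direction is ignored: components are computed on the undirected graph.
        adjacency.setdefault(src, set()).add(dst)
        adjacency.setdefault(dst, set()).add(src)

    cc = {}
    cc_id = 0
    for start in adjacency:
        if start in cc:
            continue
        # Depth-first traversal collects every node reachable from `start`.
        seen = {start}
        stack = [start]
        while stack:
            node = stack.pop()
            for neighbor in adjacency[node]:
                if neighbor not in seen:
                    seen.add(neighbor)
                    stack.append(neighbor)
        for node in seen:
            cc[node] = {'cc_id': cc_id, 'cc_size': len(seen)}
        cc_id += 1
    return cc


components = connected_components(['a', 'b', 'c', 'd', 'e'],
                                  [('a', 'b'), ('b', 'c'), ('d', 'e')])
print(components)
```

Note that in the source above the networkx branch uses an enumeration index as `cc_id`, while the graphframes branch uses the `component` value returned by graphframes, so the ids should be treated as opaque labels.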

_calculate_degree()

This internal method handles the logic of a property. It returns the degree of each node in the graph.

Source code in mercury/graph/core/graph.py
def _calculate_degree(self):
    """ This internal method handles the logic of a property. It returns the degree of each node in the graph."""

    if self._as_networkx is not None:
        return dict(self._as_networkx.degree())

    return self._fill_node_zeros({row['id']: row['degree'] for row in self.graphframe.degrees.collect()})

_calculate_edges_colnames()

This internal method returns the column names of the edges DataFrame.

Source code in mercury/graph/core/graph.py
def _calculate_edges_colnames(self):
    """ This internal method returns the column names of the edges DataFrame. """

    if self._as_networkx is not None:
        l = ['src', 'dst']
        l.extend(list(self._as_networkx.edges[list(self._as_networkx.edges.keys())[0]].keys()))
        return l

    return self.graphframe.edges.columns

_calculate_in_degree()

This internal method handles the logic of a property. It returns the in-degree of each node in the graph.

Source code in mercury/graph/core/graph.py
def _calculate_in_degree(self):
    """ This internal method handles the logic of a property. It returns the in-degree of each node in the graph."""

    if self._as_networkx is not None:
        return dict(self._as_networkx.in_degree())

    return self._fill_node_zeros({row['id']: row['inDegree'] for row in self.graphframe.inDegrees.collect()})

_calculate_nodes_colnames()

This internal method returns the column names of the nodes DataFrame.

Source code in mercury/graph/core/graph.py
def _calculate_nodes_colnames(self):
    """ This internal method returns the column names of the nodes DataFrame. """

    if self._as_networkx is not None:
        l = ['id']
        l.extend(list(self._as_networkx.nodes[list(self._as_networkx.nodes.keys())[0]].keys()))

        return l

    return self.graphframe.vertices.columns

_calculate_out_degree()

This internal method handles the logic of a property. It returns the out-degree of each node in the graph.

Source code in mercury/graph/core/graph.py
def _calculate_out_degree(self):
    """ This internal method handles the logic of a property. It returns the out-degree of each node in the graph."""

    if self._as_networkx is not None:
        return dict(self._as_networkx.out_degree())

    return self._fill_node_zeros({row['id']: row['outDegree'] for row in self.graphframe.outDegrees.collect()})

_calculate_pagerank()

This internal method handles the logic of a property. It returns the PageRank of each node in the graph as a Python dictionary.

Source code in mercury/graph/core/graph.py
def _calculate_pagerank(self):
    """
    This internal method handles the logic of a property. It returns the PageRank of each node in the graph as a Python dictionary.
    """
    if self._as_networkx is not None:
        return nx.pagerank(self._as_networkx)

    pr = self.graphframe.pageRank(resetProbability = 0.15, tol = 0.01).vertices

    return {row['id']: row['pagerank'] for row in pr.collect()}
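
For readers unfamiliar with the `resetProbability` parameter, the following is a generic power-iteration sketch of PageRank in plain Python. It is not the algorithm used by either backend and does not claim to reproduce their exact output: the function name, the adjacency-dictionary input, the uniform handling of nodes without outgoing edges, and the normalization of ranks to sum to 1 are all assumptions of this sketch.

```python
def pagerank(adjacency, reset_probability=0.15, tol=1e-8, max_iter=100):
    """Generic power-iteration PageRank; ranks sum to 1.

    `adjacency` maps every node to a list of its successors; every successor
    must also appear as a key (assumed input format).
    """
    nodes = list(adjacency)
    n = len(nodes)
    damping = 1.0 - reset_probability
    rank = {node: 1.0 / n for node in nodes}

    for _ in range(max_iter):
        # Rank held by nodes without outgoing edges is spread uniformly.
        dangling = sum(rank[node] for node in nodes if not adjacency[node])
        new_rank = {node: (reset_probability + damping * dangling) / n for node in nodes}
        for node in nodes:
            successors = adjacency[node]
            for target in successors:
                new_rank[target] += damping * rank[node] / len(successors)
        change = sum(abs(new_rank[node] - rank[node]) for node in nodes)
        rank = new_rank
        if change < tol:
            break
    return rank


cycle_ranks = pagerank({'a': ['b'], 'b': ['c'], 'c': ['a']})
star_ranks = pagerank({'a': [], 'b': ['a'], 'c': ['a']})
print(cycle_ranks, star_ranks)
```

In the symmetric cycle every node ends up with the same rank, while in the star the node that receives both links ranks highest.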

_fill_node_zeros(d)

This internal method fills the nodes that are not in the dictionary with a zero value. This makes the output obtained from graphframes consistent with the one from networkx.

Source code in mercury/graph/core/graph.py
def _fill_node_zeros(self, d):
    """
    This internal method fills the nodes that are not in the dictionary with a zero value. This make the output obtained from
    graphframes consistent with the one from networkx.
    """
    for node in self.nodes:
        if node['id'] not in d:
            d[node['id']] = 0

    return d
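
The effect of the zero-filling step can be seen on a small example. The standalone function below is a sketch that takes the node ids explicitly and returns a new dictionary, whereas the method above iterates over `self.nodes` and updates `d` in place; the function name and signature are illustrative.

```python
def fill_node_zeros(node_ids, values):
    """Return a copy of `values` in which every id from `node_ids` has an entry."""
    filled = dict(values)
    for node_id in node_ids:
        # Nodes missing from the metric (for example, with no incoming edges) get 0.
        filled.setdefault(node_id, 0)
    return filled


# 'a' has no incoming edges, so a per-edge aggregation would not list it.
in_degree = fill_node_zeros(['a', 'b', 'c'], {'b': 1, 'c': 1})
print(in_degree)
```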

_from_dataframe(edges, nodes, keys)

This internal method extends the constructor to accept a pyspark DataFrame as input.

It takes the constructor arguments and does not return anything. It sets the internal state of the object.

Source code in mercury/graph/core/graph.py
def _from_dataframe(self, edges, nodes, keys):
    """ This internal method extends the constructor to accept a pyspark DataFrame as input.

    It takes the constructor arguments and does not return anything. It sets the internal state of the object.
    """
    if keys is None:
        src = 'src'
        dst = 'dst'
        id  = 'id'
        weight = 'weight'
        directed = True
    else:
        src = keys.get('src', 'src')
        dst = keys.get('dst', 'dst')
        id  = keys.get('id', 'id')
        weight = keys.get('weight', 'weight')
        directed = keys.get('directed', True)

    edges = edges.withColumnRenamed(src, 'src').withColumnRenamed(dst, 'dst')

    if weight in edges.columns:
        edges = edges.withColumnRenamed(weight, 'weight')

    if nodes is not None:
        nodes = nodes.withColumnRenamed(id, 'id').dropDuplicates(['id'])
    else:
        src_nodes = edges.select('src').distinct().withColumnRenamed('src', 'id')
        dst_nodes = edges.select('dst').distinct().withColumnRenamed('dst', 'id')
        nodes = src_nodes.union(dst_nodes).distinct()

    g = SparkInterface().graphframes.GraphFrame(nodes, edges)

    if not directed:
        edges = g.edges

        other_columns = [col for col in edges.columns if col not in ('src', 'dst')]
        reverse_edges = edges.select(edges['dst'].alias('src'), edges['src'].alias('dst'), *other_columns)
        all_edges     = edges.union(reverse_edges).distinct()

        g = SparkInterface().graphframes.GraphFrame(nodes, all_edges)

    self._from_graphframes(g, directed)

_from_graphframes(graph, directed=True)

This internal method extends the constructor to accept a graphframes graph as input.

It takes the constructor arguments and does not return anything. It sets the internal state of the object.

Source code in mercury/graph/core/graph.py
def _from_graphframes(self, graph, directed = True):
    """ This internal method extends the constructor to accept a graphframes graph as input.

    It takes the constructor arguments and does not return anything. It sets the internal state of the object.
    """
    self._as_graphframe = graph
    self._number_of_nodes = graph.vertices.count()
    self._number_of_edges = graph.edges.count()
    self._is_directed = directed
    self._is_weighted = 'weight' in self.edges_colnames

_from_networkx(graph)

This internal method extends the constructor to accept a networkx graph as input.

It takes the constructor arguments and does not return anything. It sets the internal state of the object.

Source code in mercury/graph/core/graph.py
def _from_networkx(self, graph):
    """ This internal method extends the constructor to accept a networkx graph as input.

    It takes the constructor arguments and does not return anything. It sets the internal state of the object.
    """
    self._as_networkx = graph
    self._number_of_nodes = len(graph.nodes)
    self._number_of_edges = len(graph.edges)
    self._is_directed = nx.is_directed(graph)
    self._is_weighted = 'weight' in self.edges_colnames

_from_pandas(edges, nodes, keys)

This internal method extends the constructor to accept a pandas DataFrame as input.

It takes the constructor arguments and does not return anything. It sets the internal state of the object.

Source code in mercury/graph/core/graph.py
def _from_pandas(self, edges, nodes, keys):
    """ This internal method extends the constructor to accept a pandas DataFrame as input.

    It takes the constructor arguments and does not return anything. It sets the internal state of the object.
    """
    if keys is None:
        src = 'src'
        dst = 'dst'
        id  = 'id'
        weight = 'weight'
        directed = True
    else:
        src = keys.get('src', 'src')
        dst = keys.get('dst', 'dst')
        id  = keys.get('id', 'id')
        weight = keys.get('weight', 'weight')
        directed = keys.get('directed', True)

    if directed:
        g = nx.DiGraph()
    else:
        g = nx.Graph()

    if weight in edges.columns:
        edges = edges.rename(columns = {weight: 'weight'})

    for _, row in edges.iterrows():
        attr = row.drop([src, dst]).to_dict()
        g.add_edge(row[src], row[dst], **attr)

    if nodes is not None:
        for _, row in nodes.iterrows():
            attr = row.drop([id]).to_dict()
            g.add_node(row[id], **attr)

    self._from_networkx(g)

_to_dgl()

This internal method handles the logic of a property. It returns the dgl graph that already exists or converts it from the networkx graph if not.

Source code in mercury/graph/core/graph.py
def _to_dgl(self):
    """ This internal method handles the logic of a property. It returns the dgl graph that already exists
    or converts it from the networkx graph if not."""

    if dgl_installed:
        dgl = SparkInterface().dgl

        edge_attrs = [c for c in self.edges_colnames if c not in ['src', 'dst']]
        if len(edge_attrs) == 0:
            edge_attrs = None

        node_attrs = [c for c in self.nodes_colnames if c not in ['id']]
        if len(node_attrs) == 0:
            node_attrs = None

        self._as_dgl = dgl.from_networkx(self.networkx, edge_attrs = edge_attrs, node_attrs = node_attrs)

    return self._as_dgl

_to_graphframe()

This internal method handles the logic of a property. It returns the graphframes graph that already exists or converts it from the networkx graph if not.

Source code in mercury/graph/core/graph.py
def _to_graphframe(self):
    """ This internal method handles the logic of a property. It returns the graphframes graph that already exists
    or converts it from the networkx graph if not."""

    nodes = self.nodes_as_dataframe()
    edges = self.edges_as_dataframe()

    return SparkInterface().graphframes.GraphFrame(nodes, edges)

_to_networkx()

This internal method handles the logic of a property. It returns the networkx graph that already exists or converts it from the graphframes graph if not.

Source code in mercury/graph/core/graph.py
def _to_networkx(self):
    """ This internal method handles the logic of a property. It returns the networkx graph that already exists
    or converts it from the graphframes graph if not."""

    if self._is_directed:
        g = nx.DiGraph()
    else:
        g = nx.Graph()

    for _, row in self.edges_as_pandas().iterrows():
        attr = row.drop(['src', 'dst']).to_dict()
        g.add_edge(row['src'], row['dst'], **attr)

    for _, row in self.nodes_as_pandas().iterrows():
        attr = row.drop(['id']).to_dict()
        g.add_node(row['id'], **attr)

    return g

edges_as_dataframe()

Returns the edges as a pyspark DataFrame.

If the graph is represented as a graphframes graph, the edges are extracted from it. Otherwise, the edges are converted from the pandas DataFrame representation. The columns used as the source and destination nodes are always named 'src' and 'dst', respectively, regardless of the original column names passed to the constructor.

Source code in mercury/graph/core/graph.py
def edges_as_dataframe(self):
    """
    Returns the edges as a pyspark DataFrame.

    If the graph is represented as a graphframes graph, the edges are extracted from it. Otherwise, the edges are converted from the
    pandas DataFrame representation. The columns used as the source and destination nodes are always named 'src' and 'dst',
    respectively, regardless of the original column names passed to the constructor.
    """
    if self._as_graphframe is not None:
        return self._as_graphframe.edges

    return SparkInterface().spark.createDataFrame(self.edges_as_pandas())

edges_as_pandas()

Returns the edges as a pandas DataFrame.

If the graph is represented as a networkx graph, the edges are extracted from it. Otherwise, the graphframes graph is used. The result may differ in column names and column order from any pandas DataFrame passed to the constructor. The columns used as the source and destination nodes are always named 'src' and 'dst', respectively.

Source code in mercury/graph/core/graph.py
def edges_as_pandas(self):
    """
    Returns the edges as a pandas DataFrame.

    If the graph is represented as a networkx graph, the edges are extracted from it. Otherwise, the graphframes graph will be used.
    This dataset may differ from possible pandas DataFrame passed to the constructor in the column names and order. The columns used
    as the source and destination nodes are always named 'src' and 'dst', respectively.
    """
    if self._as_networkx is not None:
        edges_data = self._as_networkx.edges(data = True)
        edges_df   = pd.DataFrame([(src, dst, attr) for src, dst, attr in edges_data], columns = ['src', 'dst', 'attributes'])

        attrs_df   = pd.json_normalize(edges_df['attributes'])

        return pd.concat([edges_df.drop('attributes', axis = 1), attrs_df], axis = 1)

    return self.graphframe.edges.toPandas()
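
The networkx branch above turns per-edge attribute dictionaries into one column per attribute. The same flattening can be sketched without pandas, as below. The function name and the use of None for missing attributes are assumptions of this sketch (pandas marks missing values differently).

```python
def edges_to_rows(edges_data):
    """Flatten (src, dst, attributes) triples into column names and row dictionaries."""
    edges_data = list(edges_data)

    # Collect attribute names in first-seen order, after the fixed 'src' and 'dst'.
    columns = ['src', 'dst']
    for _, _, attributes in edges_data:
        for name in attributes:
            if name not in columns:
                columns.append(name)

    rows = []
    for src, dst, attributes in edges_data:
        row = {'src': src, 'dst': dst}
        for name in columns[2:]:
            row[name] = attributes.get(name)  # None marks a missing attribute
        rows.append(row)
    return columns, rows


columns, rows = edges_to_rows([
    ('a', 'b', {'weight': 1.5}),
    ('b', 'c', {'weight': 2.0, 'kind': 'road'}),
])
print(columns)
print(rows)
```

The fixed 'src' and 'dst' column names come first regardless of how the original columns were named, as the method description states.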

nodes_as_dataframe()

Returns the nodes as a pyspark DataFrame.

If the graph is represented as a graphframes graph, the nodes are extracted from it. Otherwise, the nodes are converted from the pandas DataFrame representation. The column used as the node id is always named 'id', regardless of the original column name passed to the constructor.

Source code in mercury/graph/core/graph.py
def nodes_as_dataframe(self):
    """
    Returns the nodes as a pyspark DataFrame.

    If the graph is represented as a graphframes graph, the nodes are extracted from it. Otherwise, the nodes are converted from the
    pandas DataFrame representation. The column used as the node id is always named 'id', regardless of the original column name passed
    to the constructor.
    """
    if self._as_graphframe is not None:
        return self._as_graphframe.vertices

    return SparkInterface().spark.createDataFrame(self.nodes_as_pandas())

nodes_as_pandas()

Returns the nodes as a pandas DataFrame.

If the graph is represented as a networkx graph, the nodes are extracted from it. Otherwise, the graphframes graph is used. The result may differ in column names and column order from any pandas DataFrame passed to the constructor. The column used as the node id is always named 'id'.

Source code in mercury/graph/core/graph.py
def nodes_as_pandas(self):
    """
    Returns the nodes as a pandas DataFrame.

    If the graph is represented as a networkx graph, the nodes are extracted from it. Otherwise, the graphframes graph will be used.
    This dataset may differ from possible pandas DataFrame passed to the constructor in the column names and order. The column used
    as the node id is always named 'id'.
    """
    if self._as_networkx is not None:
        nodes_data = self._as_networkx.nodes(data = True)
        nodes_df   = pd.DataFrame([(node, attr) for node, attr in nodes_data], columns = ['id', 'attributes'])

        attrs_df = pd.json_normalize(nodes_df['attributes'])

        return pd.concat([nodes_df.drop('attributes', axis = 1), attrs_df], axis = 1)

    return self.graphframe.vertices.toPandas()

mercury.graph.core.SparkInterface(config=None, session=None)

A class that provides an interface for interacting with Apache Spark, graphframes and dgl.

Attributes:

Name Type Description
_spark_session SparkSession

The shared Spark session.

_graphframes module

The shared graphframes namespace.

Methods:

Name Description
_create_spark_session

Creates a Spark session.

spark

Property that returns the shared Spark session.

pyspark

Property that returns the pyspark namespace.

graphframes

Property that returns the shared graphframes namespace.

dgl

Property that returns the shared dgl namespace.

read_csv

Reads a CSV file into a DataFrame.

read_parquet

Reads a Parquet file into a DataFrame.

read_json

Reads a JSON file into a DataFrame.

read_text

Reads a text file into a DataFrame.

read

Reads a file into a DataFrame.

sql

Executes a SQL query.

udf

Registers a user-defined function (UDF).

stop

Stops the Spark session.

Parameters:

Name Type Description Default
config dict

A dictionary of Spark configuration options. If not provided, the configuration in the global variable default_spark_config will be used.

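None
session SparkSession

An existing Spark session to use as the shared session. According to the constructor source on this page, it is used only when no shared session has been set yet.

None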
Source code in mercury/graph/core/spark_interface.py
def __init__(self, config=None, session=None):
    if SparkInterface._spark_session is None:
        if session is not None:
            SparkInterface._spark_session = session
        else:
            SparkInterface._spark_session = self._create_spark_session(config)
            # Set checkpoint directory
            SparkInterface._spark_session.sparkContext.setCheckpointDir(".checkpoint")

    if SparkInterface._graphframes is None and graphframes_installed:
        SparkInterface._graphframes = gf

    if SparkInterface._dgl is None and dgl_installed:
        SparkInterface._dgl = dgl
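
The constructor above stores the session on the class rather than on the instance, so all instances share it. A minimal sketch of that pattern follows. The class `SharedSession` and its dictionary "session" are hypothetical placeholders that stand in for the real Spark objects.

```python
class SharedSession:
    """Hypothetical stand-in: one class-level session shared by all instances."""

    _session = None  # class attribute, shared across instances

    def __init__(self, config=None, session=None):
        # Only the first construction has any effect on the shared session.
        if SharedSession._session is None:
            if session is not None:
                SharedSession._session = session
            else:
                SharedSession._session = self._create_session(config)

    @staticmethod
    def _create_session(config):
        # Placeholder for an expensive creation step.
        return {'config': dict(config or {})}

    @property
    def session(self):
        return SharedSession._session


first = SharedSession(config={'app': 'demo'})
second = SharedSession(config={'app': 'ignored'})
print(first.session is second.session)
```

One consequence visible in this sketch, and in the source above, is that a config or session passed after the shared session exists has no effect.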