From dfc3a140bc320a0f506af7e328122cb2c86a9bc7 Mon Sep 17 00:00:00 2001
From: katarinasupe
Date: Fri, 28 Jul 2023 15:26:26 +0200
Subject: [PATCH 1/3] Add support for Memgraph

---
 ChangeLog.md                                  |    1 +
 README.md                                     |   26 +
 additional-databases/memgraph/README.md       |   49 +
 .../configuration/generate_config.py          |  429 ++-
 .../configuration/get_config.py               |  137 +-
 src/graph_notebook/magics/graph_magic.py      | 3349 +++++++++++------
 src/graph_notebook/neptune/client.py          |  979 +++--
 ...-Supply-Chain-Analysis-with-Memgraph.ipynb |  655 ++++
 test/integration/IntegrationTest.py           |   78 +-
 test/integration/iam/ml/__init__.py           |   44 +-
 .../test_open_cypher_graph_notebook.py        |   89 +-
 11 files changed, 4060 insertions(+), 1776 deletions(-)
 create mode 100644 additional-databases/memgraph/README.md
 create mode 100644 src/graph_notebook/notebooks/01-Getting-Started/06-Supply-Chain-Analysis-with-Memgraph.ipynb

diff --git a/ChangeLog.md b/ChangeLog.md
index ca870098..bd326c93 100644
--- a/ChangeLog.md
+++ b/ChangeLog.md
@@ -4,6 +4,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd

 ## Upcoming
 - Added `--explain-type` option to `%%gremlin` ([Link to PR](https://github.com/aws/graph-notebook/pull/503))
+- Added Memgraph as an additional graph database, along with a supply chain analysis sample notebook ([Link to PR]())

 ## Release 3.8.2 (June 5, 2023)
 - New Sample Applications - Healthcare and Life Sciences notebooks ([Link to PR](https://github.com/aws/graph-notebook/pull/484))
diff --git a/README.md b/README.md
index b9695c1c..6d31b6cf 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ Instructions for connecting to the following graph databases:
 | [Blazegraph](#blazegraph) | RDF | SPARQL |
 |[Amazon Neptune](#amazon-neptune)| property graph or RDF | Gremlin or SPARQL |
 | [Neo4J](#neo4j) | property graph | Cypher |
+| [Memgraph](#memgraph) | property graph | Cypher |

 We encourage others to contribute configurations they find useful. There is an [`additional-databases`](https://github.com/aws/graph-notebook/blob/main/additional-databases) folder where more information can be found.

@@ -301,6 +302,31 @@ Ensure that you also specify the `%%oc bolt` option when submitting queries to t

 To setup a new local Neo4J Desktop database for use with the graph notebook, check out the [Neo4J Desktop User Interface Guide](https://neo4j.com/developer/neo4j-desktop/).

+### Memgraph
+
+Change the configuration using `%%graph_notebook_config` and modify the `host`, `port`, and `ssl` fields.
+
+After the local setup of Memgraph is complete, set the following configuration to connect from graph-notebook:
+
+```
+%%graph_notebook_config
+{
+  "host": "localhost",
+  "port": 7687,
+  "ssl": false
+}
+```
+
+Ensure that you specify the `%%oc bolt` option when submitting queries to the Bolt endpoint. For example, a correct way of running a Cypher query via the Bolt protocol is:
+
+```
+%%oc bolt
+MATCH (n)
+RETURN count(n)
+```
+
+For more details on how to run Memgraph, refer to its [notebook guide](./additional-databases/memgraph/README.md).
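+
+As a quick smoke test, you can also create and read back a small graph. The labels and properties in this cell are purely illustrative:
+
+```
+%%oc bolt
+CREATE (d:Database {name: "Memgraph"})-[:CONNECTS_TO]->(n:Notebook {name: "graph-notebook"})
+RETURN d, n
+```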
+
 ## Building From Source

 A pre-release distribution can be built from the graph-notebook repository via the following steps:
diff --git a/additional-databases/memgraph/README.md b/additional-databases/memgraph/README.md
new file mode 100644
index 00000000..9556a02c
--- /dev/null
+++ b/additional-databases/memgraph/README.md
@@ -0,0 +1,49 @@
+## Connecting graph notebook to Memgraph Bolt Endpoint
+
+[Memgraph](https://memgraph.com/) is an open-source in-memory graph database built for high-performance, advanced analytical insights. Memgraph is compatible with the Neo4j Bolt protocol and uses the standardized Cypher query language.
+
+For a quick start, run the following command in your terminal to start Memgraph Platform in a Docker container:
+
+```
+docker run -it -p 7687:7687 -p 7444:7444 -p 3000:3000 -e MEMGRAPH="--bolt-server-name-for-init=Neo4j/" memgraph/memgraph-platform
+```
+
+The above command starts the Memgraph database, MAGE (a graph algorithms library), and Memgraph Lab (a visual user interface). For additional instructions on setting up and running Memgraph locally, refer to the [Memgraph documentation](https://memgraph.com/docs/memgraph/installation). The graph notebook can only connect if the `--bolt-server-name-for-init` setting is modified as shown above. For more information on changing configuration settings, refer to the [how-to guide](https://memgraph.com/docs/memgraph/how-to-guides/config-logs).
+
+After the local setup of Memgraph is complete, set the following configuration to connect from graph-notebook:
+
+```
+%%graph_notebook_config
+{
+  "host": "localhost",
+  "port": 7687,
+  "ssl": false
+}
+```
+
+If you set up authentication on your Memgraph instance, you can provide the login details via the configuration. For example, if you created a user `username` identified by `password`, the correct configuration is:
+
+```
+%%graph_notebook_config
+{
+  "host": "localhost",
+  "port": 7687,
+  "ssl": false,
+  "memgraph": {
+    "username": "username",
+    "password": "password",
+    "auth": true
+  }
+}
+```
+
+To learn how to manage users in Memgraph, refer to the [Memgraph documentation](https://memgraph.com/docs/memgraph/reference-guide/users).
+
+You can query Memgraph via the Bolt protocol, which was designed for efficient communication with graph databases; Memgraph supports versions 1 and 4 of the protocol. Ensure that you specify the `%%oc bolt` option when submitting queries to the Bolt endpoint. For example, a correct way of running a Cypher query via the Bolt protocol is:
+
+```
+%%oc bolt
+MATCH (n)
+RETURN count(n)
+```
+
+Another way to verify that Memgraph is running is to open `localhost:3000` and check out Memgraph Lab, a visual user interface where you can see node and relationship counts and explore, query, and visualize the data. If you get stuck or have more questions, [join the Memgraph Discord community](https://www.discord.gg/memgraph).
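+
+If you want to script against the same Bolt endpoint outside of the notebook, a minimal sketch using the official `neo4j` Python driver (which Memgraph supports) could look as follows. It assumes the `username`/`password` pair created above; pass `auth=None` instead if authentication is disabled:
+
+```
+# Minimal sketch: run the same count query from a standalone Python script.
+from neo4j import GraphDatabase
+
+# Assumes the user created above; use auth=None if authentication is disabled.
+driver = GraphDatabase.driver("bolt://localhost:7687", auth=("username", "password"))
+with driver.session() as session:
+    record = session.run("MATCH (n) RETURN count(n) AS node_count").single()
+    print(record["node_count"])
+driver.close()
+```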
diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index d8720952..d54cd497 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -8,12 +8,26 @@ import os from enum import Enum -from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, DEFAULT_GREMLIN_SERIALIZER, \ - DEFAULT_GREMLIN_TRAVERSAL_SOURCE, DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE, \ - NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, \ - GRAPHSONV3_VARIANTS, GRAPHSONV2_VARIANTS, GRAPHBINARYV1_VARIANTS - -DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json') +from graph_notebook.neptune.client import ( + SPARQL_ACTION, + DEFAULT_PORT, + DEFAULT_REGION, + DEFAULT_GREMLIN_SERIALIZER, + DEFAULT_GREMLIN_TRAVERSAL_SOURCE, + DEFAULT_NEO4J_USERNAME, + DEFAULT_NEO4J_PASSWORD, + DEFAULT_NEO4J_DATABASE, + DEFAULT_MEMGRAPH_USERNAME, + DEFAULT_MEMGRAPH_PASSWORD, + NEPTUNE_CONFIG_HOST_IDENTIFIERS, + is_allowed_neptune_host, + false_str_variants, + GRAPHSONV3_VARIANTS, + GRAPHSONV2_VARIANTS, + GRAPHBINARYV1_VARIANTS, +) + +DEFAULT_CONFIG_LOCATION = os.path.expanduser("~/graph_notebook_config.json") class AuthModeEnum(Enum): @@ -29,16 +43,18 @@ class SparqlSection(object): Used for sparql-specific settings in a notebook's configuration """ - def __init__(self, path: str = SPARQL_ACTION, endpoint_prefix: str = ''): + def __init__(self, path: str = SPARQL_ACTION, endpoint_prefix: str = ""): """ :param path: used to specify the base-path of the api being connected to do get to its corresponding sparql endpoint. """ - if endpoint_prefix != '': - print('endpoint_prefix has been deprecated and will be removed in version 2.0.20 or greater.') - if path == '': - path = f'{endpoint_prefix}/sparql' + if endpoint_prefix != "": + print( + "endpoint_prefix has been deprecated and will be removed in version 2.0.20 or greater." + ) + if path == "": + path = f"{endpoint_prefix}/sparql" self.path = path @@ -51,8 +67,13 @@ class GremlinSection(object): Used for gremlin-specific settings in a notebook's configuration """ - def __init__(self, traversal_source: str = '', username: str = '', password: str = '', - message_serializer: str = ''): + def __init__( + self, + traversal_source: str = "", + username: str = "", + password: str = "", + message_serializer: str = "", + ): """ :param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are connected to an endpoint that can access multiple graphs. @@ -61,24 +82,25 @@ def __init__(self, traversal_source: str = '', username: str = '', password: str :param message_serializer: used to specify a serializer for encoding the data to and from Gremlin Server. 
""" - if traversal_source == '': + if traversal_source == "": traversal_source = DEFAULT_GREMLIN_TRAVERSAL_SOURCE serializer_lower = message_serializer.lower() - if serializer_lower == '': + if serializer_lower == "": message_serializer = DEFAULT_GREMLIN_SERIALIZER elif serializer_lower in GRAPHSONV3_VARIANTS: - message_serializer = 'graphsonv3' + message_serializer = "graphsonv3" elif serializer_lower in GRAPHSONV2_VARIANTS: - message_serializer = 'graphsonv2' + message_serializer = "graphsonv2" elif serializer_lower in GRAPHBINARYV1_VARIANTS: - message_serializer = 'graphbinaryv1' + message_serializer = "graphbinaryv1" else: - print(f'Invalid Gremlin serializer specified, defaulting to graphsonv3. ' - f'Valid serializers: [graphsonv3, graphsonv2, graphbinaryv1].') + print( + f"Invalid Gremlin serializer specified, defaulting to graphsonv3. " + f"Valid serializers: [graphsonv3, graphsonv2, graphbinaryv1]." + ) message_serializer = DEFAULT_GREMLIN_SERIALIZER - self.traversal_source = traversal_source self.username = username self.password = password @@ -93,36 +115,82 @@ class Neo4JSection(object): Used for Neo4J-specific settings in a notebook's configuration """ - def __init__(self, username: str = '', password: str = '', auth: bool = True, database: str = ''): + def __init__( + self, + username: str = "", + password: str = "", + auth: bool = True, + database: str = "", + ): """ :param username: login user for the Neo4J endpoint :param password: login password for the Neo4J endpoint """ - if username == '': + if username == "": username = DEFAULT_NEO4J_USERNAME - if password == '': + if password == "": password = DEFAULT_NEO4J_PASSWORD - if database == '': + if database == "": database = DEFAULT_NEO4J_DATABASE self.username = username self.password = password - self.auth = False if auth in [False, 'False', 'false', 'FALSE'] else True + self.auth = False if auth in [False, "False", "false", "FALSE"] else True self.database = database def to_dict(self): return self.__dict__ +class MemgraphSection(object): + """ + Used for Memgraph-specific settings in a notebook's configuration + """ + + def __init__( + self, + username: str = "", + password: str = "", + auth: bool = False, + ): + """ + :param username: login user for the Memgraph endpoint + :param password: login password for the Memgraph endpoint + :param auth: authentication switch for the Memgraph endpoint + """ + + if username == "": + username = DEFAULT_MEMGRAPH_USERNAME + if password == "": + password = DEFAULT_MEMGRAPH_PASSWORD + + self.username = username + self.password = password + self.auth = True if auth in [True, "True", "true", "TRUE"] else False + + def to_dict(self): + return self.__dict__ + + class Configuration(object): - def __init__(self, host: str, port: int, - auth_mode: AuthModeEnum = DEFAULT_AUTH_MODE, - load_from_s3_arn='', ssl: bool = True, ssl_verify: bool = True, aws_region: str = DEFAULT_REGION, - proxy_host: str = '', proxy_port: int = DEFAULT_PORT, - sparql_section: SparqlSection = None, gremlin_section: GremlinSection = None, - neo4j_section: Neo4JSection = None, - neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS): + def __init__( + self, + host: str, + port: int, + auth_mode: AuthModeEnum = DEFAULT_AUTH_MODE, + load_from_s3_arn="", + ssl: bool = True, + ssl_verify: bool = True, + aws_region: str = DEFAULT_REGION, + proxy_host: str = "", + proxy_port: int = DEFAULT_PORT, + sparql_section: SparqlSection = None, + gremlin_section: GremlinSection = None, + neo4j_section: Neo4JSection = None, + 
memgraph_section: MemgraphSection = None, + neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS, + ): self._host = host.strip() self.port = port self.ssl = ssl @@ -131,8 +199,11 @@ def __init__(self, host: str, port: int, self.proxy_port = proxy_port self.sparql = sparql_section if sparql_section is not None else SparqlSection() - is_neptune_host = is_allowed_neptune_host(hostname=self.host, host_allowlist=neptune_hosts) \ - or is_allowed_neptune_host(hostname=self.proxy_host, host_allowlist=neptune_hosts) + is_neptune_host = is_allowed_neptune_host( + hostname=self.host, host_allowlist=neptune_hosts + ) or is_allowed_neptune_host( + hostname=self.proxy_host, host_allowlist=neptune_hosts + ) if is_neptune_host: self.is_neptune_config = True self.auth_mode = auth_mode @@ -140,10 +211,16 @@ def __init__(self, host: str, port: int, self.aws_region = aws_region self.gremlin = GremlinSection() self.neo4j = Neo4JSection() + self.memgraph = MemgraphSection() else: self.is_neptune_config = False - self.gremlin = gremlin_section if gremlin_section is not None else GremlinSection() + self.gremlin = ( + gremlin_section if gremlin_section is not None else GremlinSection() + ) self.neo4j = neo4j_section if neo4j_section is not None else Neo4JSection() + self.memgraph = ( + memgraph_section if memgraph_section is not None else MemgraphSection() + ) @property def host(self): @@ -164,109 +241,235 @@ def proxy_host(self, value: str): def to_dict(self) -> dict: if self.is_neptune_config: return { - 'host': self.host, - 'port': self.port, - 'proxy_host': self.proxy_host, - 'proxy_port': self.proxy_port, - 'auth_mode': self.auth_mode.value, - 'load_from_s3_arn': self.load_from_s3_arn, - 'ssl': self.ssl, - 'ssl_verify': self.ssl_verify, - 'aws_region': self.aws_region, - 'sparql': self.sparql.to_dict(), - 'gremlin': self.gremlin.to_dict(), - 'neo4j': self.neo4j.to_dict() + "host": self.host, + "port": self.port, + "proxy_host": self.proxy_host, + "proxy_port": self.proxy_port, + "auth_mode": self.auth_mode.value, + "load_from_s3_arn": self.load_from_s3_arn, + "ssl": self.ssl, + "ssl_verify": self.ssl_verify, + "aws_region": self.aws_region, + "sparql": self.sparql.to_dict(), + "gremlin": self.gremlin.to_dict(), + "neo4j": self.neo4j.to_dict(), + "memgraph": self.memgraph.to_dict(), } else: return { - 'host': self.host, - 'port': self.port, - 'proxy_host': self.proxy_host, - 'proxy_port': self.proxy_port, - 'ssl': self.ssl, - 'ssl_verify': self.ssl_verify, - 'sparql': self.sparql.to_dict(), - 'gremlin': self.gremlin.to_dict(), - 'neo4j': self.neo4j.to_dict() + "host": self.host, + "port": self.port, + "proxy_host": self.proxy_host, + "proxy_port": self.proxy_port, + "ssl": self.ssl, + "ssl_verify": self.ssl_verify, + "sparql": self.sparql.to_dict(), + "gremlin": self.gremlin.to_dict(), + "neo4j": self.neo4j.to_dict(), + "memgraph": self.memgraph.to_dict(), } def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION): data = self.to_dict() - with open(file_path, mode='w+') as file: + with open(file_path, mode="w+") as file: json.dump(data, file, indent=2) return -def generate_config(host, port, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, ssl: bool = True, - ssl_verify: bool = True, load_from_s3_arn='', - aws_region: str = DEFAULT_REGION, proxy_host: str = '', proxy_port: int = DEFAULT_PORT, - sparql_section: SparqlSection = SparqlSection(), gremlin_section: GremlinSection = GremlinSection(), - neo4j_section=Neo4JSection(), neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS): +def generate_config( + 
host, + port, + auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, + ssl: bool = True, + ssl_verify: bool = True, + load_from_s3_arn="", + aws_region: str = DEFAULT_REGION, + proxy_host: str = "", + proxy_port: int = DEFAULT_PORT, + sparql_section: SparqlSection = SparqlSection(), + gremlin_section: GremlinSection = GremlinSection(), + neo4j_section=Neo4JSection(), + memgraph_section=MemgraphSection(), + neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS, +): use_ssl = False if ssl in false_str_variants else True verify_ssl = False if ssl_verify in false_str_variants else True - c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, verify_ssl, aws_region, proxy_host, proxy_port, - sparql_section, gremlin_section, neo4j_section, neptune_hosts) + c = Configuration( + host, + port, + auth_mode, + load_from_s3_arn, + use_ssl, + verify_ssl, + aws_region, + proxy_host, + proxy_port, + sparql_section, + gremlin_section, + neo4j_section, + memgraph_section, + neptune_hosts, + ) return c def generate_default_config(): - c = generate_config('change-me', 8182, AuthModeEnum.DEFAULT, True, True, '', DEFAULT_REGION) + c = generate_config( + "change-me", 8182, AuthModeEnum.DEFAULT, True, True, "", DEFAULT_REGION + ) return c if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--host", help="the host url to form a connection with", required=True) - parser.add_argument("--port", help="the port to use when creating a connection", default=8182) - parser.add_argument("--auth_mode", default=AuthModeEnum.DEFAULT.value, - help="type of authentication the cluster being connected to is using. Can be DEFAULT or IAM") - parser.add_argument("--ssl", - help="whether to make connections to the created endpoint with ssl or not [True|False]", - default=True) - parser.add_argument("--ssl-verify", - help="whether to verify the server's TLS certificate or not [True|False]", - default=True) + parser.add_argument( + "--host", help="the host url to form a connection with", required=True + ) + parser.add_argument( + "--port", help="the port to use when creating a connection", default=8182 + ) + parser.add_argument( + "--auth_mode", + default=AuthModeEnum.DEFAULT.value, + help="type of authentication the cluster being connected to is using. Can be DEFAULT or IAM", + ) + parser.add_argument( + "--ssl", + help="whether to make connections to the created endpoint with ssl or not [True|False]", + default=True, + ) + parser.add_argument( + "--ssl-verify", + help="whether to verify the server's TLS certificate or not [True|False]", + default=True, + ) # TODO: Remove this after we fix the LC script in S3. - parser.add_argument("--iam_credentials_provider", default='ROLE', - help="The mode used to obtain credentials for IAM Authentication. 
Can be ROLE or ENV") - parser.add_argument("--config_destination", help="location to put generated config", - default=DEFAULT_CONFIG_LOCATION) - parser.add_argument("--load_from_s3_arn", help="arn of role to use for bulk loader", default='') - parser.add_argument("--aws_region", help="aws region your ml cluster is in.", default=DEFAULT_REGION) - parser.add_argument("--proxy_host", help="the proxy host url to route a connection through", default='') - parser.add_argument("--proxy_port", help="the proxy port to use when creating proxy connection", default=8182) - parser.add_argument("--neptune_hosts", nargs="*", - help="list of host snippets to use for identifying neptune endpoints", - default=NEPTUNE_CONFIG_HOST_IDENTIFIERS) - parser.add_argument("--sparql_path", help="the namespace path to append to the SPARQL endpoint", - default=SPARQL_ACTION) - parser.add_argument("--gremlin_traversal_source", help="the traversal source to use for Gremlin queries", - default=DEFAULT_GREMLIN_TRAVERSAL_SOURCE) - parser.add_argument("--gremlin_username", help="the username to use when creating Gremlin connections", default='') - parser.add_argument("--gremlin_password", help="the password to use when creating Gremlin connections", default='') - parser.add_argument("--gremlin_serializer", - help="the serializer to use as the encoding format when creating Gremlin connections", - default=DEFAULT_GREMLIN_SERIALIZER) - parser.add_argument("--neo4j_username", help="the username to use for Neo4J connections", - default=DEFAULT_NEO4J_USERNAME) - parser.add_argument("--neo4j_password", help="the password to use for Neo4J connections", - default=DEFAULT_NEO4J_PASSWORD) - parser.add_argument("--neo4j_auth", help="whether to use auth for Neo4J connections or not [True|False]", - default=True) - parser.add_argument("--neo4j_database", help="the name of the database to use for Neo4J", - default=DEFAULT_NEO4J_DATABASE) + parser.add_argument( + "--iam_credentials_provider", + default="ROLE", + help="The mode used to obtain credentials for IAM Authentication. 
Can be ROLE or ENV",
+    )
+    parser.add_argument(
+        "--config_destination",
+        help="location to put generated config",
+        default=DEFAULT_CONFIG_LOCATION,
+    )
+    parser.add_argument(
+        "--load_from_s3_arn", help="arn of role to use for bulk loader", default=""
+    )
+    parser.add_argument(
+        "--aws_region", help="aws region your ml cluster is in.", default=DEFAULT_REGION
+    )
+    parser.add_argument(
+        "--proxy_host",
+        help="the proxy host url to route a connection through",
+        default="",
+    )
+    parser.add_argument(
+        "--proxy_port",
+        help="the proxy port to use when creating a proxy connection",
+        default=8182,
+    )
+    parser.add_argument(
+        "--neptune_hosts",
+        nargs="*",
+        help="list of host snippets to use for identifying neptune endpoints",
+        default=NEPTUNE_CONFIG_HOST_IDENTIFIERS,
+    )
+    parser.add_argument(
+        "--sparql_path",
+        help="the namespace path to append to the SPARQL endpoint",
+        default=SPARQL_ACTION,
+    )
+    parser.add_argument(
+        "--gremlin_traversal_source",
+        help="the traversal source to use for Gremlin queries",
+        default=DEFAULT_GREMLIN_TRAVERSAL_SOURCE,
+    )
+    parser.add_argument(
+        "--gremlin_username",
+        help="the username to use when creating Gremlin connections",
+        default="",
+    )
+    parser.add_argument(
+        "--gremlin_password",
+        help="the password to use when creating Gremlin connections",
+        default="",
+    )
+    parser.add_argument(
+        "--gremlin_serializer",
+        help="the serializer to use as the encoding format when creating Gremlin connections",
+        default=DEFAULT_GREMLIN_SERIALIZER,
+    )
+    parser.add_argument(
+        "--neo4j_username",
+        help="the username to use for Neo4J connections",
+        default=DEFAULT_NEO4J_USERNAME,
+    )
+    parser.add_argument(
+        "--neo4j_password",
+        help="the password to use for Neo4J connections",
+        default=DEFAULT_NEO4J_PASSWORD,
+    )
+    parser.add_argument(
+        "--neo4j_auth",
+        help="whether to use auth for Neo4J connections or not [True|False]",
+        default=True,
+    )
+    parser.add_argument(
+        "--neo4j_database",
+        help="the name of the database to use for Neo4J",
+        default=DEFAULT_NEO4J_DATABASE,
+    )
+    parser.add_argument(
+        "--memgraph_username",
+        help="the username to use for Memgraph connections",
+        default=DEFAULT_MEMGRAPH_USERNAME,
+    )
+    parser.add_argument(
+        "--memgraph_password",
+        help="the password to use for Memgraph connections",
+        default=DEFAULT_MEMGRAPH_PASSWORD,
+    )
+    parser.add_argument(
+        "--memgraph_auth",
+        help="whether to use auth for Memgraph connections or not [True|False]",
+        default=True,
+    )

     args = parser.parse_args()
-    auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
-    config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl, args.ssl_verify,
-                             args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port),
-                             SparqlSection(args.sparql_path, ''),
-                             GremlinSection(args.gremlin_traversal_source, args.gremlin_username,
-                                            args.gremlin_password, args.gremlin_serializer),
-                             Neo4JSection(args.neo4j_username, args.neo4j_password,
-                                          args.neo4j_auth, args.neo4j_database),
-                             args.neptune_hosts)
+    auth_mode_arg = (
+        args.auth_mode if args.auth_mode != "" else AuthModeEnum.DEFAULT.value
+    )
+    config = generate_config(
+        args.host,
+        int(args.port),
+        AuthModeEnum(auth_mode_arg),
+        args.ssl,
+        args.ssl_verify,
+        args.load_from_s3_arn,
+        args.aws_region,
+        args.proxy_host,
+        int(args.proxy_port),
+        SparqlSection(args.sparql_path, ""),
+        GremlinSection(
+            args.gremlin_traversal_source,
+            args.gremlin_username,
+            args.gremlin_password,
+            args.gremlin_serializer,
+        ),
+
Neo4JSection( + args.neo4j_username, + args.neo4j_password, + args.neo4j_auth, + args.neo4j_database, + ), + MemgraphSection( + args.memgraph_username, args.memgraph_password, args.memgraph_auth + ), + args.neptune_hosts, + ) config.write_to_file(args.config_destination) exit(0) diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py index 5a8dad1e..696a6473 100644 --- a/src/graph_notebook/configuration/get_config.py +++ b/src/graph_notebook/configuration/get_config.py @@ -5,57 +5,128 @@ import json -from graph_notebook.configuration.generate_config import DEFAULT_CONFIG_LOCATION, Configuration, AuthModeEnum, \ - SparqlSection, GremlinSection, Neo4JSection -from graph_notebook.neptune.client import NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, false_str_variants, \ - DEFAULT_NEO4J_USERNAME, DEFAULT_NEO4J_PASSWORD, DEFAULT_NEO4J_DATABASE +from graph_notebook.configuration.generate_config import ( + DEFAULT_CONFIG_LOCATION, + Configuration, + AuthModeEnum, + SparqlSection, + GremlinSection, + Neo4JSection, + MemgraphSection, +) +from graph_notebook.neptune.client import ( + NEPTUNE_CONFIG_HOST_IDENTIFIERS, + is_allowed_neptune_host, + false_str_variants, + DEFAULT_NEO4J_USERNAME, + DEFAULT_NEO4J_PASSWORD, + DEFAULT_NEO4J_DATABASE, + DEFAULT_MEMGRAPH_USERNAME, + DEFAULT_MEMGRAPH_PASSWORD, +) -neptune_params = ['auth_mode', 'load_from_s3_arn', 'aws_region'] +neptune_params = ["auth_mode", "load_from_s3_arn", "aws_region"] -def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration: - if 'ssl_verify' in data: - ssl_verify = False if data['ssl_verify'] in false_str_variants else True +def get_config_from_dict( + data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS +) -> Configuration: + if "ssl_verify" in data: + ssl_verify = False if data["ssl_verify"] in false_str_variants else True else: ssl_verify = True - sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('') - gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection() - neo4j_section = Neo4JSection(**data['neo4j']) if 'neo4j' in data else Neo4JSection('', '', True, '') - proxy_host = str(data['proxy_host']) if 'proxy_host' in data else '' - proxy_port = int(data['proxy_port']) if 'proxy_port' in data else 8182 + sparql_section = ( + SparqlSection(**data["sparql"]) if "sparql" in data else SparqlSection("") + ) + gremlin_section = ( + GremlinSection(**data["gremlin"]) if "gremlin" in data else GremlinSection() + ) + neo4j_section = ( + Neo4JSection(**data["neo4j"]) + if "neo4j" in data + else Neo4JSection("", "", True, "") + ) + memgraph_section = ( + MemgraphSection(**data["memgraph"]) + if "memgraph" in data + else MemgraphSection("", "", False) + ) + proxy_host = str(data["proxy_host"]) if "proxy_host" in data else "" + proxy_port = int(data["proxy_port"]) if "proxy_port" in data else 8182 - is_neptune_host = is_allowed_neptune_host(hostname=data["host"], host_allowlist=neptune_hosts) + is_neptune_host = is_allowed_neptune_host( + hostname=data["host"], host_allowlist=neptune_hosts + ) if is_neptune_host: - if gremlin_section.to_dict()['traversal_source'] != 'g': - print('Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n') - if neo4j_section.to_dict()['username'] != DEFAULT_NEO4J_USERNAME \ - or neo4j_section.to_dict()['password'] != DEFAULT_NEO4J_PASSWORD: - print('Ignoring Neo4J custom authentication, 
Amazon Neptune does not support this functionality.\n') - if neo4j_section.to_dict()['database'] != DEFAULT_NEO4J_DATABASE: - print('Ignoring Neo4J custom database, Amazon Neptune does not support multiple databases.\n') - config = Configuration(host=data['host'], port=data['port'], auth_mode=AuthModeEnum(data['auth_mode']), - ssl=data['ssl'], ssl_verify=ssl_verify, load_from_s3_arn=data['load_from_s3_arn'], - aws_region=data['aws_region'], sparql_section=sparql_section, - gremlin_section=gremlin_section, neo4j_section=neo4j_section, - proxy_host=proxy_host, proxy_port=proxy_port, neptune_hosts=neptune_hosts) + if gremlin_section.to_dict()["traversal_source"] != "g": + print( + "Ignoring custom traversal source, Amazon Neptune does not support this functionality.\n" + ) + if ( + neo4j_section.to_dict()["username"] != DEFAULT_NEO4J_USERNAME + or neo4j_section.to_dict()["password"] != DEFAULT_NEO4J_PASSWORD + ): + print( + "Ignoring Neo4J custom authentication, Amazon Neptune does not support this functionality.\n" + ) + if neo4j_section.to_dict()["database"] != DEFAULT_NEO4J_DATABASE: + print( + "Ignoring Neo4J custom database, Amazon Neptune does not support multiple databases.\n" + ) + if ( + memgraph_section.to_dict()["username"] != DEFAULT_MEMGRAPH_USERNAME + or memgraph_section.to_dict()["password"] != DEFAULT_MEMGRAPH_PASSWORD + ): + print( + "Ignoring Memgraph custom authentication, Amazon Neptune does not support this functionality.\n" + ) + config = Configuration( + host=data["host"], + port=data["port"], + auth_mode=AuthModeEnum(data["auth_mode"]), + ssl=data["ssl"], + ssl_verify=ssl_verify, + load_from_s3_arn=data["load_from_s3_arn"], + aws_region=data["aws_region"], + sparql_section=sparql_section, + gremlin_section=gremlin_section, + neo4j_section=neo4j_section, + memgraph_section=memgraph_section, + proxy_host=proxy_host, + proxy_port=proxy_port, + neptune_hosts=neptune_hosts, + ) else: excluded_params = [] for p in neptune_params: if p in data: excluded_params.append(p) if excluded_params: - print(f"The provided configuration contains the following parameters that are incompatible with the " - f"specified host: {str(excluded_params)}. These parameters have not been saved.\n") + print( + f"The provided configuration contains the following parameters that are incompatible with the " + f"specified host: {str(excluded_params)}. 
These parameters have not been saved.\n" + ) - config = Configuration(host=data['host'], port=data['port'], ssl=data['ssl'], ssl_verify=ssl_verify, - sparql_section=sparql_section, gremlin_section=gremlin_section, neo4j_section=neo4j_section, - proxy_host=proxy_host, proxy_port=proxy_port) + config = Configuration( + host=data["host"], + port=data["port"], + ssl=data["ssl"], + ssl_verify=ssl_verify, + sparql_section=sparql_section, + gremlin_section=gremlin_section, + neo4j_section=neo4j_section, + memgraph_section=memgraph_section, + proxy_host=proxy_host, + proxy_port=proxy_port, + ) return config -def get_config(path: str = DEFAULT_CONFIG_LOCATION, - neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration: +def get_config( + path: str = DEFAULT_CONFIG_LOCATION, + neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS, +) -> Configuration: with open(path) as config_file: data = json.load(config_file) return get_config_from_dict(data=data, neptune_hosts=neptune_hosts) diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index 1eb3d90f..d3cecdc9 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -30,32 +30,81 @@ from botocore.session import get_session from gremlin_python.driver.protocol import GremlinServerError from IPython.core.display import HTML, display_html, display -from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic, line_cell_magic, needs_local_scope) +from IPython.core.magic import ( + Magics, + magics_class, + cell_magic, + line_magic, + line_cell_magic, + needs_local_scope, +) from ipywidgets.widgets.widget_description import DescriptionStyle from requests import HTTPError import graph_notebook -from graph_notebook.configuration.generate_config import generate_default_config, DEFAULT_CONFIG_LOCATION, \ - AuthModeEnum, Configuration +from graph_notebook.configuration.generate_config import ( + generate_default_config, + DEFAULT_CONFIG_LOCATION, + AuthModeEnum, + Configuration, +) from graph_notebook.decorators.decorators import display_exceptions, magic_variables -from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser +from graph_notebook.magics.ml import ( + neptune_ml_magic_handler, + generate_neptune_ml_parser, +) from graph_notebook.magics.streams import StreamViewer -from graph_notebook.neptune.client import ClientBuilder, Client, VALID_FORMATS, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \ - LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \ - FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \ - NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, \ - STATISTICS_LANGUAGE_INPUTS, STATISTICS_MODES, SUMMARY_MODES, \ - SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES +from graph_notebook.neptune.client import ( + ClientBuilder, + Client, + VALID_FORMATS, + PARALLELISM_OPTIONS, + PARALLELISM_HIGH, + LOAD_JOB_MODES, + MODE_AUTO, + FINAL_LOAD_STATUSES, + SPARQL_ACTION, + FORMAT_CSV, + FORMAT_OPENCYPHER, + FORMAT_NTRIPLE, + FORMAT_NQUADS, + FORMAT_RDFXML, + FORMAT_TURTLE, + STREAM_RDF, + STREAM_PG, + STREAM_ENDPOINTS, + NEPTUNE_CONFIG_HOST_IDENTIFIERS, + is_allowed_neptune_host, + STATISTICS_LANGUAGE_INPUTS, + STATISTICS_MODES, + SUMMARY_MODES, + SPARQL_EXPLAIN_MODES, + OPENCYPHER_EXPLAIN_MODES, +) from graph_notebook.network import SPARQLNetwork -from graph_notebook.network.gremlin.GremlinNetwork import 
parse_pattern_list_str, GremlinNetwork -from graph_notebook.visualization.rows_and_columns import sparql_get_rows_and_columns, opencypher_get_rows_and_columns +from graph_notebook.network.gremlin.GremlinNetwork import ( + parse_pattern_list_str, + GremlinNetwork, +) +from graph_notebook.visualization.rows_and_columns import ( + sparql_get_rows_and_columns, + opencypher_get_rows_and_columns, +) from graph_notebook.visualization.template_retriever import retrieve_template from graph_notebook.configuration.get_config import get_config, get_config_from_dict -from graph_notebook.seed.load_query import get_data_sets, get_queries, normalize_model_name, normalize_language_name +from graph_notebook.seed.load_query import ( + get_data_sets, + get_queries, + normalize_model_name, + normalize_language_name, +) from graph_notebook.widgets import Force from graph_notebook.options import OPTIONS_DEFAULT_DIRECTED, vis_options_merge -from graph_notebook.magics.metadata import build_sparql_metadata_from_query, build_gremlin_metadata_from_query, \ - build_opencypher_metadata_from_query +from graph_notebook.magics.metadata import ( + build_sparql_metadata_from_query, + build_gremlin_metadata_from_query, + build_opencypher_metadata_from_query, +) sparql_table_template = retrieve_template("sparql_table.html") sparql_explain_template = retrieve_template("sparql_explain.html") @@ -69,7 +118,9 @@ error_template = retrieve_template("error.html") loading_wheel_html = loading_wheel_template.render() -DEFAULT_LAYOUT = widgets.Layout(max_height='600px', max_width='940px', overflow='scroll') +DEFAULT_LAYOUT = widgets.Layout( + max_height="600px", max_width="940px", overflow="scroll" +) UNRESTRICTED_LAYOUT = widgets.Layout() DEFAULT_PAGINATION_OPTIONS = [10, 25, 50, 100, -1] @@ -97,18 +148,18 @@ DEFAULT_MAX_RESULTS = 1000 -GREMLIN_CANCEL_HINT_MSG = '''You must supply a string queryId when using --cancelQuery, - for example: %gremlin_status --cancelQuery --queryId my-query-id''' -SPARQL_CANCEL_HINT_MSG = '''You must supply a string queryId when using --cancelQuery, - for example: %sparql_status --cancelQuery --queryId my-query-id''' -OPENCYPHER_CANCEL_HINT_MSG = '''You must supply a string queryId when using --cancelQuery, - for example: %opencypher_status --cancelQuery --queryId my-query-id''' -SEED_MODEL_OPTIONS = ['', 'propertygraph', 'rdf'] -SEED_LANGUAGE_OPTIONS = ['', 'gremlin', 'opencypher', 'sparql'] -SEED_SOURCE_OPTIONS = ['', 'samples', 'custom'] +GREMLIN_CANCEL_HINT_MSG = """You must supply a string queryId when using --cancelQuery, + for example: %gremlin_status --cancelQuery --queryId my-query-id""" +SPARQL_CANCEL_HINT_MSG = """You must supply a string queryId when using --cancelQuery, + for example: %sparql_status --cancelQuery --queryId my-query-id""" +OPENCYPHER_CANCEL_HINT_MSG = """You must supply a string queryId when using --cancelQuery, + for example: %opencypher_status --cancelQuery --queryId my-query-id""" +SEED_MODEL_OPTIONS = ["", "propertygraph", "rdf"] +SEED_LANGUAGE_OPTIONS = ["", "gremlin", "opencypher", "sparql"] +SEED_SOURCE_OPTIONS = ["", "samples", "custom"] SEED_NO_DATASETS_FOUND_MSG = "(No datasets available)" -LOADER_FORMAT_CHOICES = [''] +LOADER_FORMAT_CHOICES = [""] LOADER_FORMAT_CHOICES.extend(VALID_FORMATS) serializers_map = { @@ -116,7 +167,7 @@ "GRAPHSON_V2D0": "application/vnd.gremlin-v2.0+json", "GRAPHSON_V3D0": "application/vnd.gremlin-v3.0+json", "GRYO_V3D0": "application/vnd.gremlin-v3.0+gryo", - "GRAPHBINARY_V1D0": "application/vnd.graphbinary-v1.0" + "GRAPHBINARY_V1D0": 
"application/vnd.graphbinary-v1.0", } DEFAULT_NAMEDGRAPH_URI = "http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph" @@ -142,19 +193,39 @@ MEDIA_TYPE_TRIG = "application/trig" MEDIA_TYPE_RDF4J_BINARY = "application/x-binary-rdf" -NEPTUNE_RDF_SELECT_FORMATS = [MEDIA_TYPE_SPARQL_JSON, MEDIA_TYPE_SPARQL_XML, MEDIA_TYPE_BINARY_RESULTS_TABLE, - MEDIA_TYPE_SPARQL_CSV, MEDIA_TYPE_SPARQL_TSV] -NEPTUNE_RDF_ASK_FORMATS = [MEDIA_TYPE_SPARQL_JSON, MEDIA_TYPE_SPARQL_XML, MEDIA_TYPE_BOOLEAN] -NEPTUNE_RDF_CONSTRUCT_DESCRIBE_FORMATS = [MEDIA_TYPE_SPARQL_JSON, MEDIA_TYPE_NQUADS, MEDIA_TYPE_NQUADS_TEXT, - MEDIA_TYPE_RDF_XML, MEDIA_TYPE_JSON_LD, MEDIA_TYPE_NTRIPLES, - MEDIA_TYPE_NTRIPLES_TEXT, MEDIA_TYPE_TURTLE, MEDIA_TYPE_N3, MEDIA_TYPE_TRIX, - MEDIA_TYPE_TRIG, MEDIA_TYPE_RDF4J_BINARY] +NEPTUNE_RDF_SELECT_FORMATS = [ + MEDIA_TYPE_SPARQL_JSON, + MEDIA_TYPE_SPARQL_XML, + MEDIA_TYPE_BINARY_RESULTS_TABLE, + MEDIA_TYPE_SPARQL_CSV, + MEDIA_TYPE_SPARQL_TSV, +] +NEPTUNE_RDF_ASK_FORMATS = [ + MEDIA_TYPE_SPARQL_JSON, + MEDIA_TYPE_SPARQL_XML, + MEDIA_TYPE_BOOLEAN, +] +NEPTUNE_RDF_CONSTRUCT_DESCRIBE_FORMATS = [ + MEDIA_TYPE_SPARQL_JSON, + MEDIA_TYPE_NQUADS, + MEDIA_TYPE_NQUADS_TEXT, + MEDIA_TYPE_RDF_XML, + MEDIA_TYPE_JSON_LD, + MEDIA_TYPE_NTRIPLES, + MEDIA_TYPE_NTRIPLES_TEXT, + MEDIA_TYPE_TURTLE, + MEDIA_TYPE_N3, + MEDIA_TYPE_TRIX, + MEDIA_TYPE_TRIG, + MEDIA_TYPE_RDF4J_BINARY, +] + class QueryMode(Enum): - DEFAULT = 'query' - EXPLAIN = 'explain' - PROFILE = 'profile' - EMPTY = '' + DEFAULT = "query" + EXPLAIN = "explain" + PROFILE = "profile" + EMPTY = "" def generate_seed_error_msg(error_content, file_name, line_index=None): @@ -166,7 +237,7 @@ def generate_seed_error_msg(error_content, file_name, line_index=None): def store_to_ns(key: str, value, ns: dict = None): - if key == '' or ns is None: + if key == "" or ns is None: return ns[key] = value @@ -178,29 +249,29 @@ def str_to_query_mode(s: str) -> QueryMode: if mode.value == s: return QueryMode(s) - logger.debug(f'Invalid query mode {s} supplied, defaulting to query.') + logger.debug(f"Invalid query mode {s} supplied, defaulting to query.") return QueryMode.DEFAULT ACTION_TO_QUERY_TYPE = { - 'sparql': 'application/sparql-query', - 'sparqlupdate': 'application/sparql-update' + "sparql": "application/sparql-query", + "sparqlupdate": "application/sparql-update", } def get_query_type(query): - s = SPARQLWrapper('') + s = SPARQLWrapper("") s.setQuery(query) return s.queryType def query_type_to_action(query_type): query_type = query_type.upper() - if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: - return 'sparql' + if query_type in ["SELECT", "CONSTRUCT", "ASK", "DESCRIBE"]: + return "sparql" else: # TODO: check explicitly for all query types, raise exception for invalid query - return 'sparqlupdate' + return "sparqlupdate" def results_per_page_check(results_per_page): @@ -227,7 +298,13 @@ def generate_pagination_vars(visible_results: int): def replace_html_chars(result): - html_char_map = {"&": "&", "<": "<", ">": ">", '"': """, "'": "'"} + html_char_map = { + "&": "&", + "<": "<", + ">": ">", + '"': """, + "'": "'", + } fixed_result = str(result) for k, v in iter(html_char_map.items()): @@ -241,24 +318,33 @@ def get_load_ids(neptune_client): load_status.raise_for_status() res = load_status.json() ids = [] - if 'payload' in res and 'loadIds' in res['payload']: - ids = res['payload']['loadIds'] + if "payload" in res and "loadIds" in res["payload"]: + ids = res["payload"]["loadIds"] return ids, res def process_statistics_400(is_summary: bool, response): 
bad_request_res = json.loads(response.text)
-    res_code = bad_request_res['code']
-    if res_code == 'StatisticsNotAvailableException':
-        print("No statistics found. Please ensure that auto-generation of DFE statistics is enabled by running "
-              "'%statistics' and checking if 'autoCompute' if set to True. Alternately, you can manually "
-              "trigger statistics generation by running: '%statistics --mode refresh'.")
+    res_code = bad_request_res["code"]
+    if res_code == "StatisticsNotAvailableException":
+        print(
+            "No statistics found. Please ensure that auto-generation of DFE statistics is enabled by running "
+            "'%statistics' and checking if 'autoCompute' is set to True. Alternatively, you can manually "
+            "trigger statistics generation by running: '%statistics --mode refresh'."
+        )
     elif res_code == "BadRequestException":
-        print("Unable to query the statistics endpoint. Please check that your Neptune instance is of size r5.large or "
-              "greater in order to have DFE statistics enabled.")
-        if is_summary and "Statistics is disabled" not in bad_request_res["detailedMessage"]:
-            print("\nPlease also note that the Graph Summary API is only available in Neptune engine version 1.2.1.0 "
-                  "and later.")
+        print(
+            "Unable to query the statistics endpoint. Please check that your Neptune instance is of size r5.large or "
+            "greater in order to have DFE statistics enabled."
+        )
+        if (
+            is_summary
+            and "Statistics is disabled" not in bad_request_res["detailedMessage"]
+        ):
+            print(
+                "\nPlease also note that the Graph Summary API is only available in Neptune engine version 1.2.1.0 "
+                "and later."
+            )
     else:
         print("Query encountered 400 error, please see below.")
         print(f"\nFull response: {bad_request_res}")
@@ -275,12 +361,18 @@ def __init__(self, shell):
         self.neptune_cfg_allowlist = copy(NEPTUNE_CONFIG_HOST_IDENTIFIERS)
         self.graph_notebook_config = generate_default_config()
         try:
-            self.config_location = os.getenv('GRAPH_NOTEBOOK_CONFIG', DEFAULT_CONFIG_LOCATION)
+            self.config_location = os.getenv(
+                "GRAPH_NOTEBOOK_CONFIG", DEFAULT_CONFIG_LOCATION
+            )
             self.client: Client = None
-            self.graph_notebook_config = get_config(self.config_location, neptune_hosts=self.neptune_cfg_allowlist)
+            self.graph_notebook_config = get_config(
+                self.config_location, neptune_hosts=self.neptune_cfg_allowlist
+            )
         except FileNotFoundError:
-            print('Could not find a valid configuration. '
-                  'Do not forget to validate your settings using %graph_notebook_config.')
+            print(
+                "Could not find a valid configuration. "
+                "Do not forget to validate your settings using %graph_notebook_config."
+ ) self.max_results = DEFAULT_MAX_RESULTS self.graph_notebook_vis_options = OPTIONS_DEFAULT_DIRECTED @@ -292,35 +384,50 @@ def _generate_client_from_config(self, config: Configuration): if self.client: self.client.close() - is_neptune_host = is_allowed_neptune_host(hostname=config.host, host_allowlist=self.neptune_cfg_allowlist) + is_neptune_host = is_allowed_neptune_host( + hostname=config.host, host_allowlist=self.neptune_cfg_allowlist + ) if is_neptune_host: - builder = ClientBuilder() \ - .with_host(config.host) \ - .with_port(config.port) \ - .with_region(config.aws_region) \ - .with_tls(config.ssl) \ - .with_ssl_verify(config.ssl_verify) \ - .with_proxy_host(config.proxy_host) \ - .with_proxy_port(config.proxy_port) \ - .with_sparql_path(config.sparql.path) \ + builder = ( + ClientBuilder() + .with_host(config.host) + .with_port(config.port) + .with_region(config.aws_region) + .with_tls(config.ssl) + .with_ssl_verify(config.ssl_verify) + .with_proxy_host(config.proxy_host) + .with_proxy_port(config.proxy_port) + .with_sparql_path(config.sparql.path) .with_gremlin_serializer(config.gremlin.message_serializer) + ) if config.auth_mode == AuthModeEnum.IAM: builder = builder.with_iam(get_session()) if self.neptune_cfg_allowlist != NEPTUNE_CONFIG_HOST_IDENTIFIERS: builder = builder.with_custom_neptune_hosts(self.neptune_cfg_allowlist) else: - builder = ClientBuilder() \ - .with_host(config.host) \ - .with_port(config.port) \ - .with_tls(config.ssl) \ - .with_ssl_verify(config.ssl_verify) \ - .with_sparql_path(config.sparql.path) \ - .with_gremlin_traversal_source(config.gremlin.traversal_source) \ - .with_gremlin_login(config.gremlin.username, config.gremlin.password) \ - .with_gremlin_serializer(config.gremlin.message_serializer) \ - .with_neo4j_login(config.neo4j.username, config.neo4j.password, config.neo4j.auth, - config.neo4j.database) + builder = ( + ClientBuilder() + .with_host(config.host) + .with_port(config.port) + .with_tls(config.ssl) + .with_ssl_verify(config.ssl_verify) + .with_sparql_path(config.sparql.path) + .with_gremlin_traversal_source(config.gremlin.traversal_source) + .with_gremlin_login(config.gremlin.username, config.gremlin.password) + .with_gremlin_serializer(config.gremlin.message_serializer) + .with_neo4j_login( + config.neo4j.username, + config.neo4j.password, + config.neo4j.auth, + config.neo4j.database, + ) + .with_memgraph_login( + config.memgraph.username, + config.memgraph.password, + config.memgraph.auth, + ) + ) self.client = builder.build() @@ -328,25 +435,38 @@ def _generate_client_from_config(self, config: Configuration): @line_cell_magic @needs_local_scope @display_exceptions - def graph_notebook_config(self, line='', cell='', local_ns: dict = None): + def graph_notebook_config(self, line="", cell="", local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('mode', nargs='?', default='show', - help='mode (default=show) [show|reset|silent]') - parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') + parser.add_argument( + "mode", + nargs="?", + default="show", + help="mode (default=show) [show|reset|silent]", + ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) args = parser.parse_args(line.split()) - if cell != '': + if cell != "": data = json.loads(cell) - config = get_config_from_dict(data, neptune_hosts=self.neptune_cfg_allowlist) + config = get_config_from_dict( + data, neptune_hosts=self.neptune_cfg_allowlist + 
) self.graph_notebook_config = config self._generate_client_from_config(config) - print('set notebook config to:') + print("set notebook config to:") print(json.dumps(self.graph_notebook_config.to_dict(), indent=2)) - elif args.mode == 'reset': - self.graph_notebook_config = get_config(self.config_location, neptune_hosts=self.neptune_cfg_allowlist) - print('reset notebook config to:') + elif args.mode == "reset": + self.graph_notebook_config = get_config( + self.config_location, neptune_hosts=self.neptune_cfg_allowlist + ) + print("reset notebook config to:") print(json.dumps(self.graph_notebook_config.to_dict(), indent=2)) - elif args.mode == 'silent': + elif args.mode == "silent": """ silent option to that our neptune_menu extension can receive json instead of python Configuration object @@ -358,40 +478,52 @@ def graph_notebook_config(self, line='', cell='', local_ns: dict = None): config_dict = self.graph_notebook_config.to_dict() print(json.dumps(config_dict, indent=2)) - store_to_ns(args.store_to, json.dumps(self.graph_notebook_config.to_dict(), indent=2), local_ns) + store_to_ns( + args.store_to, + json.dumps(self.graph_notebook_config.to_dict(), indent=2), + local_ns, + ) return self.graph_notebook_config @line_cell_magic - def neptune_config_allowlist(self, line='', cell=''): + def neptune_config_allowlist(self, line="", cell=""): parser = argparse.ArgumentParser() - parser.add_argument('mode', nargs='?', default='add', - help='mode (default=add) [add|remove|overwrite|reset]') + parser.add_argument( + "mode", + nargs="?", + default="add", + help="mode (default=add) [add|remove|overwrite|reset]", + ) args = parser.parse_args(line.split()) try: cell_new = ast.literal_eval(cell) - input_type = 'list' + input_type = "list" except: cell_new = cell - input_type = 'string' + input_type = "string" allowlist_modified = True - if args.mode == 'reset': + if args.mode == "reset": self.neptune_cfg_allowlist = copy(NEPTUNE_CONFIG_HOST_IDENTIFIERS) - elif cell != '': - if args.mode == 'add': - if input_type == 'string': + elif cell != "": + if args.mode == "add": + if input_type == "string": self.neptune_cfg_allowlist.append(cell_new.strip()) else: - self.neptune_cfg_allowlist = list(set(self.neptune_cfg_allowlist) | set(cell_new)) - elif args.mode == 'remove': - if input_type == 'string': + self.neptune_cfg_allowlist = list( + set(self.neptune_cfg_allowlist) | set(cell_new) + ) + elif args.mode == "remove": + if input_type == "string": self.neptune_cfg_allowlist.remove(cell_new.strip()) else: - self.neptune_cfg_allowlist = list(set(self.neptune_cfg_allowlist) - set(cell_new)) - elif args.mode == 'overwrite': - if input_type == 'string': + self.neptune_cfg_allowlist = list( + set(self.neptune_cfg_allowlist) - set(cell_new) + ) + elif args.mode == "overwrite": + if input_type == "string": self.neptune_cfg_allowlist = [cell_new.strip()] else: self.neptune_cfg_allowlist = cell_new @@ -399,25 +531,34 @@ def neptune_config_allowlist(self, line='', cell=''): allowlist_modified = False if allowlist_modified: - print(f'Set Neptune config allow list to: {self.neptune_cfg_allowlist}') + print(f"Set Neptune config allow list to: {self.neptune_cfg_allowlist}") else: - print(f'Current Neptune config allow list: {self.neptune_cfg_allowlist}') + print(f"Current Neptune config allow list: {self.neptune_cfg_allowlist}") @line_magic - def stream_viewer(self,line): + def stream_viewer(self, line): parser = argparse.ArgumentParser() - parser.add_argument('language', nargs='?', default=STREAM_PG, - help=f'language 
(default={STREAM_PG}) [{STREAM_PG}|{STREAM_RDF}]', - choices = [STREAM_PG, STREAM_RDF]) + parser.add_argument( + "language", + nargs="?", + default=STREAM_PG, + help=f"language (default={STREAM_PG}) [{STREAM_PG}|{STREAM_RDF}]", + choices=[STREAM_PG, STREAM_RDF], + ) - parser.add_argument('--limit', type=int, default=10, help='Maximum number of rows to display at a time') + parser.add_argument( + "--limit", + type=int, + default=10, + help="Maximum number of rows to display at a time", + ) args = parser.parse_args(line.split()) language = args.language limit = args.limit uri = self.client.get_uri_with_port() - viewer = StreamViewer(self.client,uri,language,limit=limit) + viewer = StreamViewer(self.client, uri, language, limit=limit) viewer.show() @line_magic @@ -425,27 +566,51 @@ def stream_viewer(self,line): @display_exceptions def statistics(self, line, local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('language', nargs='?', type=str.lower, default="propertygraph", - help=f'The language endpoint to use. Valid inputs: {STATISTICS_LANGUAGE_INPUTS}. ' - f'Default: propertygraph.', - choices=STATISTICS_LANGUAGE_INPUTS) - parser.add_argument('-m', '--mode', type=str, default='', - help=f'The action to perform on the statistics endpoint. Valid inputs: {STATISTICS_MODES}. ' - f'Default: `basic` if `--summary` is specified, otherwise `status`.') - parser.add_argument('--summary', action='store_true', default=False, help="Retrieves the graph summary.") - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='') + parser.add_argument( + "language", + nargs="?", + type=str.lower, + default="propertygraph", + help=f"The language endpoint to use. Valid inputs: {STATISTICS_LANGUAGE_INPUTS}. " + f"Default: propertygraph.", + choices=STATISTICS_LANGUAGE_INPUTS, + ) + parser.add_argument( + "-m", + "--mode", + type=str, + default="", + help=f"The action to perform on the statistics endpoint. Valid inputs: {STATISTICS_MODES}. " + f"Default: `basic` if `--summary` is specified, otherwise `status`.", + ) + parser.add_argument( + "--summary", + action="store_true", + default=False, + help="Retrieves the graph summary.", + ) + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." + ) + parser.add_argument("--store-to", type=str, default="") args = parser.parse_args(line.split()) mode = args.mode if not mode: - mode = 'basic' if args.summary else 'status' - elif (args.summary and mode not in SUMMARY_MODES) or (not args.summary and mode not in STATISTICS_MODES): - err_endpoint_type, err_mode_list, err_default_mode = ("summary", SUMMARY_MODES[1:], "basic summary view") \ - if args.summary else ("statistics", STATISTICS_MODES[1:], "status") - print(f'Invalid {err_endpoint_type} mode. Please specify one of: {err_mode_list}, ' - f'or leave blank to retrieve {err_default_mode}.') + mode = "basic" if args.summary else "status" + elif (args.summary and mode not in SUMMARY_MODES) or ( + not args.summary and mode not in STATISTICS_MODES + ): + err_endpoint_type, err_mode_list, err_default_mode = ( + ("summary", SUMMARY_MODES[1:], "basic summary view") + if args.summary + else ("statistics", STATISTICS_MODES[1:], "status") + ) + print( + f"Invalid {err_endpoint_type} mode. Please specify one of: {err_mode_list}, " + f"or leave blank to retrieve {err_default_mode}." 
+ ) return statistics_res = self.client.statistics(args.language, args.summary, mode) @@ -467,15 +632,26 @@ def statistics(self, line, local_ns: dict = None): @display_exceptions def summary(self, line, local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('language', nargs='?', type=str.lower, default="propertygraph", - help=f'The language endpoint to use. Valid inputs: {STATISTICS_LANGUAGE_INPUTS}. ' - f'Default: propertygraph.', - choices=STATISTICS_LANGUAGE_INPUTS) - parser.add_argument('--detailed', action='store_true', default=False, - help="Toggles the display of structures fields on or off in the output. If not supplied, " - "we will default to the basic summary display mode.") - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='') + parser.add_argument( + "language", + nargs="?", + type=str.lower, + default="propertygraph", + help=f"The language endpoint to use. Valid inputs: {STATISTICS_LANGUAGE_INPUTS}. " + f"Default: propertygraph.", + choices=STATISTICS_LANGUAGE_INPUTS, + ) + parser.add_argument( + "--detailed", + action="store_true", + default=False, + help="Toggles the display of structures fields on or off in the output. If not supplied, " + "we will default to the basic summary display mode.", + ) + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." + ) + parser.add_argument("--store-to", type=str, default="") args = parser.parse_args(line.split()) if args.detailed: @@ -496,67 +672,166 @@ def summary(self, line, local_ns: dict = None): @line_magic def graph_notebook_host(self, line): - if line == '': - print(f'current host: {self.graph_notebook_config.host}') + if line == "": + print(f"current host: {self.graph_notebook_config.host}") return # TODO: we should attempt to make a status call to this host before we set the config to this value. self.graph_notebook_config.host = line self._generate_client_from_config(self.graph_notebook_config) - print(f'set host to {self.graph_notebook_config.host}') + print(f"set host to {self.graph_notebook_config.host}") @magic_variables @cell_magic @needs_local_scope @display_exceptions - def sparql(self, line='', cell='', local_ns: dict = None): + def sparql(self, line="", cell="", local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('query_mode', nargs='?', default='query', - help='query mode (default=query) [query|explain]') - parser.add_argument('--path', '-p', default='', - help='prefix path to sparql endpoint. For example, if "foo/bar" were specified, ' - 'the endpoint called would be host:port/foo/bar') - parser.add_argument('--expand-all', action='store_true') - parser.add_argument('--explain-type', type=str.lower, default='dynamic', - help=f'Explain mode to use when using the explain query mode. ' - f'Expected values: ${SPARQL_EXPLAIN_MODES}') - parser.add_argument('--explain-format', default='text/html', help='response format for explain query mode', - choices=['text/csv', 'text/html']) - parser.add_argument('-m', '--media-type', type=str, default='', - help='Response format for SELECT/CONSTRUCT/DESCRIBE queries. See ' - 'https://docs.aws.amazon.com/neptune/latest/userguide/sparql-media-type-support.html ' - 'for valid RDF media types supported by Neptune for each query type. 
Default for ' - 'Neptune and SELECT queries is application/sparql-results+json, otherwise no format ' - 'will be specified in the request.') - parser.add_argument('-g', '--group-by', type=str, default='', - help='Property used to group nodes.') - parser.add_argument('-gr', '--group-by-raw', action='store_true', default=False, - help="Group nodes by the raw binding") - parser.add_argument('-d', '--display-property', type=str, default='', - help='Property to display the value of on each node.') - parser.add_argument('-de', '--edge-display-property', type=str, default='', - help='Property to display the value of on each edge.') - parser.add_argument('-t', '--tooltip-property', type=str, default='', - help='Property to display the value of on each node tooltip.') - parser.add_argument('-te', '--edge-tooltip-property', type=str, default='', - help='Property to display the value of on each edge tooltip.') - parser.add_argument('-l', '--label-max-length', type=int, default=10, - help='Specifies max length of vertex labels, in characters. Default is 10') - parser.add_argument('-le', '--edge-label-max-length', type=int, default=10, - help='Specifies max length of edge labels, in characters. Default is 10') - parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') - parser.add_argument('--ignore-groups', action='store_true', default=False, help="Ignore all grouping options") - parser.add_argument('-sp', '--stop-physics', action='store_true', default=False, - help="Disable visualization physics after the initial simulation stabilizes.") - parser.add_argument('-sd', '--simulation-duration', type=int, default=1500, - help='Specifies maximum duration of visualization physics simulation. Default is 1500ms') - parser.add_argument('--silent', action='store_true', default=False, help="Display no query output.") - parser.add_argument('-r', '--results-per-page', type=int, default=10, - help='Specifies how many query results to display per page in the output. Default is 10') - parser.add_argument('--no-scroll', action='store_true', default=False, - help="Display the entire output without a scroll bar.") - parser.add_argument('--hide-index', action='store_true', default=False, - help="Hide the index column numbers when displaying the results.") + parser.add_argument( + "query_mode", + nargs="?", + default="query", + help="query mode (default=query) [query|explain]", + ) + parser.add_argument( + "--path", + "-p", + default="", + help='prefix path to sparql endpoint. For example, if "foo/bar" were specified, ' + "the endpoint called would be host:port/foo/bar", + ) + parser.add_argument("--expand-all", action="store_true") + parser.add_argument( + "--explain-type", + type=str.lower, + default="dynamic", + help=f"Explain mode to use when using the explain query mode. " + f"Expected values: ${SPARQL_EXPLAIN_MODES}", + ) + parser.add_argument( + "--explain-format", + default="text/html", + help="response format for explain query mode", + choices=["text/csv", "text/html"], + ) + parser.add_argument( + "-m", + "--media-type", + type=str, + default="", + help="Response format for SELECT/CONSTRUCT/DESCRIBE queries. See " + "https://docs.aws.amazon.com/neptune/latest/userguide/sparql-media-type-support.html " + "for valid RDF media types supported by Neptune for each query type. 
Default for " + "Neptune and SELECT queries is application/sparql-results+json, otherwise no format " + "will be specified in the request.", + ) + parser.add_argument( + "-g", + "--group-by", + type=str, + default="", + help="Property used to group nodes.", + ) + parser.add_argument( + "-gr", + "--group-by-raw", + action="store_true", + default=False, + help="Group nodes by the raw binding", + ) + parser.add_argument( + "-d", + "--display-property", + type=str, + default="", + help="Property to display the value of on each node.", + ) + parser.add_argument( + "-de", + "--edge-display-property", + type=str, + default="", + help="Property to display the value of on each edge.", + ) + parser.add_argument( + "-t", + "--tooltip-property", + type=str, + default="", + help="Property to display the value of on each node tooltip.", + ) + parser.add_argument( + "-te", + "--edge-tooltip-property", + type=str, + default="", + help="Property to display the value of on each edge tooltip.", + ) + parser.add_argument( + "-l", + "--label-max-length", + type=int, + default=10, + help="Specifies max length of vertex labels, in characters. Default is 10", + ) + parser.add_argument( + "-le", + "--edge-label-max-length", + type=int, + default=10, + help="Specifies max length of edge labels, in characters. Default is 10", + ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) + parser.add_argument( + "--ignore-groups", + action="store_true", + default=False, + help="Ignore all grouping options", + ) + parser.add_argument( + "-sp", + "--stop-physics", + action="store_true", + default=False, + help="Disable visualization physics after the initial simulation stabilizes.", + ) + parser.add_argument( + "-sd", + "--simulation-duration", + type=int, + default=1500, + help="Specifies maximum duration of visualization physics simulation. Default is 1500ms", + ) + parser.add_argument( + "--silent", + action="store_true", + default=False, + help="Display no query output.", + ) + parser.add_argument( + "-r", + "--results-per-page", + type=int, + default=10, + help="Specifies how many query results to display per page in the output. 
Default is 10", + ) + parser.add_argument( + "--no-scroll", + action="store_true", + default=False, + help="Display the entire output without a scroll bar.", + ) + parser.add_argument( + "--hide-index", + action="store_true", + default=False, + help="Hide the index column numbers when displaying the results.", + ) args = parser.parse_args(line.split()) mode = str_to_query_mode(args.query_mode) @@ -579,23 +854,28 @@ def sparql(self, line='', cell='', local_ns: dict = None): first_tab_output = widgets.Output(layout=sparql_layout) children.append(first_tab_output) - path = args.path if args.path != '' else self.graph_notebook_config.sparql.path - logger.debug(f'using mode={mode}') + path = args.path if args.path != "" else self.graph_notebook_config.sparql.path + logger.debug(f"using mode={mode}") results_df = None if mode == QueryMode.EXPLAIN: - res = self.client.sparql_explain(cell, args.explain_type, args.explain_format, path=path) + res = self.client.sparql_explain( + cell, args.explain_type, args.explain_format, path=path + ) res.raise_for_status() - explain_bytes = res.content.replace(b'\xcc', b'-') - explain_bytes = explain_bytes.replace(b'\xb6', b'') - explain = explain_bytes.decode('utf-8') + explain_bytes = res.content.replace(b"\xcc", b"-") + explain_bytes = explain_bytes.replace(b"\xb6", b"") + explain = explain_bytes.decode("utf-8") store_to_ns(args.store_to, explain, local_ns) if not args.silent: - sparql_metadata = build_sparql_metadata_from_query(query_type='explain', res=res) - titles.append('Explain') - explain_bytes = explain.encode('ascii', 'ignore') - base64_str = base64.b64encode(explain_bytes).decode('ascii') - first_tab_html = sparql_explain_template.render(table=explain, - link=f"data:text/html;base64,{base64_str}") + sparql_metadata = build_sparql_metadata_from_query( + query_type="explain", res=res + ) + titles.append("Explain") + explain_bytes = explain.encode("ascii", "ignore") + base64_str = base64.b64encode(explain_bytes).decode("ascii") + first_tab_html = sparql_explain_template.render( + table=explain, link=f"data:text/html;base64,{base64_str}" + ) else: query_type = get_query_type(cell) @@ -607,23 +887,37 @@ def sparql(self, line='', cell='', local_ns: dict = None): # query type. We will only verify if media types are valid for Neptune # (https://docs.aws.amazon.com/neptune/latest/userguide/sparql-media-type-support.html). For other # databases, we will rely on the HTTP query response to tell if there is an issue with the format. - if is_allowed_neptune_host(self.graph_notebook_config.host, NEPTUNE_CONFIG_HOST_IDENTIFIERS): - if (query_type == 'SELECT' and result_type not in NEPTUNE_RDF_SELECT_FORMATS) \ - or (query_type == 'ASK' and result_type not in NEPTUNE_RDF_ASK_FORMATS) \ - or (query_type in ['CONSTRUCT', 'DESCRIBE'] - and result_type not in NEPTUNE_RDF_CONSTRUCT_DESCRIBE_FORMATS) \ - or result_type == '': - if result_type != '': - print(f"Invalid media type: {result_type} specified for Neptune {query_type} query. 
" - f"Defaulting to: {MEDIA_TYPE_SPARQL_JSON}.") + if is_allowed_neptune_host( + self.graph_notebook_config.host, NEPTUNE_CONFIG_HOST_IDENTIFIERS + ): + if ( + ( + query_type == "SELECT" + and result_type not in NEPTUNE_RDF_SELECT_FORMATS + ) + or ( + query_type == "ASK" + and result_type not in NEPTUNE_RDF_ASK_FORMATS + ) + or ( + query_type in ["CONSTRUCT", "DESCRIBE"] + and result_type not in NEPTUNE_RDF_CONSTRUCT_DESCRIBE_FORMATS + ) + or result_type == "" + ): + if result_type != "": + print( + f"Invalid media type: {result_type} specified for Neptune {query_type} query. " + f"Defaulting to: {MEDIA_TYPE_SPARQL_JSON}." + ) result_type = MEDIA_TYPE_SPARQL_JSON - headers = {'Accept': result_type} - elif result_type == '': - if query_type == 'SELECT': + headers = {"Accept": result_type} + elif result_type == "": + if query_type == "SELECT": result_type = MEDIA_TYPE_SPARQL_JSON - headers = {'Accept': MEDIA_TYPE_SPARQL_JSON} + headers = {"Accept": MEDIA_TYPE_SPARQL_JSON} else: - headers = {'Accept': result_type} + headers = {"Accept": result_type} query_res = self.client.sparql(cell, path=path, headers=headers) @@ -632,14 +926,20 @@ def sparql(self, line='', cell='', local_ns: dict = None): except HTTPError: # Catching all 400 response errors here to try and fix possible invalid media type for db in headers. # Retry query once with RDF spec default media type. - result_type = MEDIA_TYPE_SPARQL_JSON if query_type == 'SELECT' else MEDIA_TYPE_NTRIPLES - query_res = self.client.sparql(cell, path=path, headers={'Accept': result_type}) + result_type = ( + MEDIA_TYPE_SPARQL_JSON + if query_type == "SELECT" + else MEDIA_TYPE_NTRIPLES + ) + query_res = self.client.sparql( + cell, path=path, headers={"Accept": result_type} + ) query_res.raise_for_status() try: results = query_res.json() except Exception: - results = query_res.content.decode('utf-8') + results = query_res.content.decode("utf-8") store_to_ns(args.store_to, results, local_ns) if not args.silent: @@ -653,24 +953,26 @@ def sparql(self, line='', cell='', local_ns: dict = None): with raw_output: print(results) children.append(raw_output) - titles.append('Raw') + titles.append("Raw") else: - if query_type in ['SELECT', 'CONSTRUCT', 'DESCRIBE']: + if query_type in ["SELECT", "CONSTRUCT", "DESCRIBE"]: # TODO: Serialize other result types to SPARQL JSON so we can create table and visualization - logger.debug('creating sparql network...') - - titles.append('Table') - - sn = SPARQLNetwork(group_by_property=args.group_by, - display_property=args.display_property, - edge_display_property=args.edge_display_property, - tooltip_property=args.tooltip_property, - edge_tooltip_property=args.edge_tooltip_property, - label_max_length=args.label_max_length, - edge_label_max_length=args.edge_label_max_length, - ignore_groups=args.ignore_groups, - expand_all=args.expand_all, - group_by_raw=args.group_by_raw) + logger.debug("creating sparql network...") + + titles.append("Table") + + sn = SPARQLNetwork( + group_by_property=args.group_by, + display_property=args.display_property, + edge_display_property=args.edge_display_property, + tooltip_property=args.tooltip_property, + edge_tooltip_property=args.edge_tooltip_property, + label_max_length=args.label_max_length, + edge_label_max_length=args.edge_label_max_length, + ignore_groups=args.ignore_groups, + expand_all=args.expand_all, + group_by_raw=args.group_by_raw, + ) sn.extract_prefix_declarations_from_query(cell) try: @@ -678,52 +980,68 @@ def sparql(self, line='', cell='', local_ns: dict = None): except 
ValueError as value_error: logger.debug(value_error) - logger.debug(f'number of nodes is {len(sn.graph.nodes)}') + logger.debug(f"number of nodes is {len(sn.graph.nodes)}") if len(sn.graph.nodes) > 0: - self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ - = args.stop_physics - self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration - f = Force(network=sn, options=self.graph_notebook_vis_options) - titles.append('Graph') + self.graph_notebook_vis_options["physics"][ + "disablePhysicsAfterInitialSimulation" + ] = args.stop_physics + self.graph_notebook_vis_options["physics"][ + "simulationDuration" + ] = args.simulation_duration + f = Force( + network=sn, options=self.graph_notebook_vis_options + ) + titles.append("Graph") children.append(f) - logger.debug('added sparql network to tabs') + logger.debug("added sparql network to tabs") rows_and_columns = sparql_get_rows_and_columns(results) if rows_and_columns is not None: table_id = f"table-{str(uuid.uuid4())[:8]}" - visible_results = results_per_page_check(args.results_per_page) - first_tab_html = sparql_table_template.render(columns=rows_and_columns['columns'], - rows=rows_and_columns['rows'], guid=table_id, - amount=visible_results) + visible_results = results_per_page_check( + args.results_per_page + ) + first_tab_html = sparql_table_template.render( + columns=rows_and_columns["columns"], + rows=rows_and_columns["rows"], + guid=table_id, + amount=visible_results, + ) # Handling CONSTRUCT and DESCRIBE on their own because we want to maintain the previous result # pattern of showing a tsv with each line being a result binding in addition to new ones. - if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE': + if query_type == "CONSTRUCT" or query_type == "DESCRIBE": lines = [] - for b in results['results']['bindings']: - lines.append(f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}') + for b in results["results"]["bindings"]: + lines.append( + f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}' + ) raw_output = widgets.Output(layout=sparql_layout) with raw_output: html = sparql_construct_template.render(lines=lines) display(HTML(html)) children.append(raw_output) - titles.append('Raw') + titles.append("Raw") json_output = widgets.Output(layout=sparql_layout) with json_output: print(json.dumps(results, indent=2)) children.append(json_output) - titles.append('JSON') + titles.append("JSON") - sparql_metadata = build_sparql_metadata_from_query(query_type='query', res=query_res, results=results) + sparql_metadata = build_sparql_metadata_from_query( + query_type="query", res=query_res, results=results + ) if not args.silent: metadata_output = widgets.Output(layout=sparql_layout) children.append(metadata_output) - titles.append('Query Metadata') + titles.append("Query Metadata") if first_tab_html == "" and results_df is None: - tab.children = children[1:] # the first tab is empty, remove it and proceed + tab.children = children[ + 1: + ] # the first tab is empty, remove it and proceed else: tab.children = children @@ -737,27 +1055,34 @@ def sparql(self, line='', cell='', local_ns: dict = None): if results_df is not None: with first_tab_output: - visible_results, final_pagination_options, final_pagination_menu = generate_pagination_vars( - args.results_per_page) + ( + visible_results, + final_pagination_options, + final_pagination_menu, + ) = generate_pagination_vars(args.results_per_page) sparql_columndefs = [ {"width": 
"5%", "targets": 0}, {"visible": True, "targets": 0}, {"searchable": False, "targets": 0}, {"className": "nowrap dt-left", "targets": "_all"}, {"createdCell": JavascriptFunction(index_col_js), "targets": 0}, - {"createdCell": JavascriptFunction(cell_style_js), "targets": "_all"} + { + "createdCell": JavascriptFunction(cell_style_js), + "targets": "_all", + }, ] if args.hide_index: sparql_columndefs[1]["visible"] = False - show(results_df, - scrollX=True, - scrollY=sparql_scrollY, - columnDefs=sparql_columndefs, - paging=sparql_paging, - scrollCollapse=sparql_scrollCollapse, - lengthMenu=[final_pagination_options, final_pagination_menu], - pageLength=visible_results - ) + show( + results_df, + scrollX=True, + scrollY=sparql_scrollY, + columnDefs=sparql_columndefs, + paging=sparql_paging, + scrollCollapse=sparql_scrollCollapse, + lengthMenu=[final_pagination_options, final_pagination_menu], + pageLength=visible_results, + ) elif first_tab_html != "": with first_tab_output: display(HTML(first_tab_html)) @@ -765,18 +1090,37 @@ def sparql(self, line='', cell='', local_ns: dict = None): @line_magic @needs_local_scope @display_exceptions - def sparql_status(self, line='', local_ns: dict = None): + def sparql_status(self, line="", local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('-q', '--queryId', default='', - help='The ID of a running SPARQL query. Only displays the status of the specified query.') - parser.add_argument('-c', '--cancelQuery', action='store_true', - help='Tells the status command to cancel a query. This parameter does not take a value') - parser.add_argument('-s', '--silent-cancel', action='store_true', - help='If silent_cancel=true then the running query is cancelled and the HTTP response code ' - 'is 200. If silent_cancel is not present or silent_cancel=false, ' - 'the query is cancelled with an HTTP 500 status code.') - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') + parser.add_argument( + "-q", + "--queryId", + default="", + help="The ID of a running SPARQL query. Only displays the status of the specified query.", + ) + parser.add_argument( + "-c", + "--cancelQuery", + action="store_true", + help="Tells the status command to cancel a query. This parameter does not take a value", + ) + parser.add_argument( + "-s", + "--silent-cancel", + action="store_true", + help="If silent_cancel=true then the running query is cancelled and the HTTP response code " + "is 200. If silent_cancel is not present or silent_cancel=false, " + "the query is cancelled with an HTTP 500 status code.", + ) + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." 
+ ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) args = parser.parse_args(line.split()) if not args.cancelQuery: @@ -784,7 +1128,7 @@ def sparql_status(self, line='', local_ns: dict = None): status_res.raise_for_status() res = status_res.json() else: - if args.queryId == '': + if args.queryId == "": if not args.silent: print(SPARQL_CANCEL_HINT_MSG) return @@ -803,63 +1147,179 @@ def sparql_status(self, line='', local_ns: dict = None): @display_exceptions def gremlin(self, line, cell, local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('query_mode', nargs='?', default='query', - help='query mode (default=query) [query|explain|profile]') - parser.add_argument('--explain-type', type=str.lower, default='', - help='Explain mode to use when using the explain query mode.') - parser.add_argument('-p', '--path-pattern', default='', help='path pattern') - parser.add_argument('-g', '--group-by', type=str, default='T.label', - help='Property used to group nodes (e.g. code, T.region) default is T.label') - parser.add_argument('-gd', '--group-by-depth', action='store_true', default=False, - help="Group nodes based on path hierarchy") - parser.add_argument('-gr', '--group-by-raw', action='store_true', default=False, - help="Group nodes by the raw result") - parser.add_argument('-d', '--display-property', type=str, default='T.label', - help='Property to display the value of on each node, default is T.label') - parser.add_argument('-de', '--edge-display-property', type=str, default='T.label', - help='Property to display the value of on each edge, default is T.label') - parser.add_argument('-t', '--tooltip-property', type=str, default='', - help='Property to display the value of on each node tooltip. If not specified, tooltip ' - 'will default to the node label value.') - parser.add_argument('-te', '--edge-tooltip-property', type=str, default='', - help='Property to display the value of on each edge tooltip. If not specified, tooltip ' - 'will default to the edge label value.') - parser.add_argument('-l', '--label-max-length', type=int, default=10, - help='Specifies max length of vertex label, in characters. Default is 10') - parser.add_argument('-le', '--edge-label-max-length', type=int, default=10, - help='Specifies max length of edge labels, in characters. Default is 10') - parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') - parser.add_argument('--ignore-groups', action='store_true', default=False, help="Ignore all grouping options") - parser.add_argument('--profile-no-results', action='store_false', default=True, - help='Display only the result count. If not used, all query results will be displayed in ' - 'the profile report by default.') - parser.add_argument('--profile-chop', type=int, default=250, - help='Property to specify max length of profile results string. Default is 250') - parser.add_argument('--profile-serializer', type=str, default='application/json', - help='Specify how to serialize results. Allowed values are any of the valid MIME type or ' - 'TinkerPop driver "Serializers" enum values. 
Default is application/json') - parser.add_argument('--profile-indexOps', action='store_true', default=False, - help='Show a detailed report of all index operations.') - parser.add_argument('--profile-misc-args', type=str, default='{}', - help='Additional profile options, passed in as a map.') - parser.add_argument('-sp', '--stop-physics', action='store_true', default=False, - help="Disable visualization physics after the initial simulation stabilizes.") - parser.add_argument('-sd', '--simulation-duration', type=int, default=1500, - help='Specifies maximum duration of visualization physics simulation. Default is 1500ms') - parser.add_argument('--silent', action='store_true', default=False, help="Display no query output.") - parser.add_argument('-r', '--results-per-page', type=int, default=10, - help='Specifies how many query results to display per page in the output. Default is 10') - parser.add_argument('--no-scroll', action='store_true', default=False, - help="Display the entire output without a scroll bar.") - parser.add_argument('--hide-index', action='store_true', default=False, - help="Hide the index column numbers when displaying the results.") - parser.add_argument('-mcl', '--max-content-length', type=int, default=10*1024*1024, - help="Specifies maximum size (in bytes) of results that can be returned to the " - "GremlinPython client. Default is 10MB") + parser.add_argument( + "query_mode", + nargs="?", + default="query", + help="query mode (default=query) [query|explain|profile]", + ) + parser.add_argument( + "--explain-type", + type=str.lower, + default="", + help="Explain mode to use when using the explain query mode.", + ) + parser.add_argument("-p", "--path-pattern", default="", help="path pattern") + parser.add_argument( + "-g", + "--group-by", + type=str, + default="T.label", + help="Property used to group nodes (e.g. code, T.region) default is T.label", + ) + parser.add_argument( + "-gd", + "--group-by-depth", + action="store_true", + default=False, + help="Group nodes based on path hierarchy", + ) + parser.add_argument( + "-gr", + "--group-by-raw", + action="store_true", + default=False, + help="Group nodes by the raw result", + ) + parser.add_argument( + "-d", + "--display-property", + type=str, + default="T.label", + help="Property to display the value of on each node, default is T.label", + ) + parser.add_argument( + "-de", + "--edge-display-property", + type=str, + default="T.label", + help="Property to display the value of on each edge, default is T.label", + ) + parser.add_argument( + "-t", + "--tooltip-property", + type=str, + default="", + help="Property to display the value of on each node tooltip. If not specified, tooltip " + "will default to the node label value.", + ) + parser.add_argument( + "-te", + "--edge-tooltip-property", + type=str, + default="", + help="Property to display the value of on each edge tooltip. If not specified, tooltip " + "will default to the edge label value.", + ) + parser.add_argument( + "-l", + "--label-max-length", + type=int, + default=10, + help="Specifies max length of vertex label, in characters. Default is 10", + ) + parser.add_argument( + "-le", + "--edge-label-max-length", + type=int, + default=10, + help="Specifies max length of edge labels, in characters. 
Default is 10", + ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) + parser.add_argument( + "--ignore-groups", + action="store_true", + default=False, + help="Ignore all grouping options", + ) + parser.add_argument( + "--profile-no-results", + action="store_false", + default=True, + help="Display only the result count. If not used, all query results will be displayed in " + "the profile report by default.", + ) + parser.add_argument( + "--profile-chop", + type=int, + default=250, + help="Property to specify max length of profile results string. Default is 250", + ) + parser.add_argument( + "--profile-serializer", + type=str, + default="application/json", + help="Specify how to serialize results. Allowed values are any of the valid MIME type or " + 'TinkerPop driver "Serializers" enum values. Default is application/json', + ) + parser.add_argument( + "--profile-indexOps", + action="store_true", + default=False, + help="Show a detailed report of all index operations.", + ) + parser.add_argument( + "--profile-misc-args", + type=str, + default="{}", + help="Additional profile options, passed in as a map.", + ) + parser.add_argument( + "-sp", + "--stop-physics", + action="store_true", + default=False, + help="Disable visualization physics after the initial simulation stabilizes.", + ) + parser.add_argument( + "-sd", + "--simulation-duration", + type=int, + default=1500, + help="Specifies maximum duration of visualization physics simulation. Default is 1500ms", + ) + parser.add_argument( + "--silent", + action="store_true", + default=False, + help="Display no query output.", + ) + parser.add_argument( + "-r", + "--results-per-page", + type=int, + default=10, + help="Specifies how many query results to display per page in the output. Default is 10", + ) + parser.add_argument( + "--no-scroll", + action="store_true", + default=False, + help="Display the entire output without a scroll bar.", + ) + parser.add_argument( + "--hide-index", + action="store_true", + default=False, + help="Hide the index column numbers when displaying the results.", + ) + parser.add_argument( + "-mcl", + "--max-content-length", + type=int, + default=10 * 1024 * 1024, + help="Specifies maximum size (in bytes) of results that can be returned to the " + "GremlinPython client. 
Default is 10MB", + ) args = parser.parse_args(line.split()) mode = str_to_query_mode(args.query_mode) - logger.debug(f'Arguments {args}') + logger.debug(f"Arguments {args}") results_df = None if args.no_scroll: @@ -881,109 +1341,139 @@ def gremlin(self, line, cell, local_ns: dict = None): first_tab_output = widgets.Output(layout=gremlin_layout) children.append(first_tab_output) - transport_args = {'max_content_length': args.max_content_length} + transport_args = {"max_content_length": args.max_content_length} if mode == QueryMode.EXPLAIN: - res = self.client.gremlin_explain(cell, - args={'explain.mode': args.explain_type} if args.explain_type else {}) + res = self.client.gremlin_explain( + cell, + args={"explain.mode": args.explain_type} if args.explain_type else {}, + ) res.raise_for_status() # Replace strikethrough character bytes, can't be encoded to ASCII - explain_bytes = res.content.replace(b'\xcc', b'-') - explain_bytes = explain_bytes.replace(b'\xb6', b'') - query_res = explain_bytes.decode('utf-8') + explain_bytes = res.content.replace(b"\xcc", b"-") + explain_bytes = explain_bytes.replace(b"\xb6", b"") + query_res = explain_bytes.decode("utf-8") if not args.silent: - gremlin_metadata = build_gremlin_metadata_from_query(query_type='explain', results=query_res, res=res) - titles.append('Explain') - if 'Neptune Gremlin Explain' in query_res: - explain_bytes = query_res.encode('ascii', 'ignore') - base64_str = base64.b64encode(explain_bytes).decode('ascii') - first_tab_html = gremlin_explain_profile_template.render(content=query_res, - link=f"data:text/html;base64,{base64_str}") + gremlin_metadata = build_gremlin_metadata_from_query( + query_type="explain", results=query_res, res=res + ) + titles.append("Explain") + if "Neptune Gremlin Explain" in query_res: + explain_bytes = query_res.encode("ascii", "ignore") + base64_str = base64.b64encode(explain_bytes).decode("ascii") + first_tab_html = gremlin_explain_profile_template.render( + content=query_res, link=f"data:text/html;base64,{base64_str}" + ) else: - first_tab_html = pre_container_template.render(content='No explain found') + first_tab_html = pre_container_template.render( + content="No explain found" + ) elif mode == QueryMode.PROFILE: - logger.debug(f'results: {args.profile_no_results}') - logger.debug(f'chop: {args.profile_chop}') - logger.debug(f'serializer: {args.profile_serializer}') - logger.debug(f'indexOps: {args.profile_indexOps}') + logger.debug(f"results: {args.profile_no_results}") + logger.debug(f"chop: {args.profile_chop}") + logger.debug(f"serializer: {args.profile_serializer}") + logger.debug(f"indexOps: {args.profile_indexOps}") if args.profile_serializer in serializers_map: serializer = serializers_map[args.profile_serializer] else: serializer = args.profile_serializer - profile_args = {"profile.results": args.profile_no_results, - "profile.chop": args.profile_chop, - "profile.serializer": serializer, - "profile.indexOps": args.profile_indexOps} + profile_args = { + "profile.results": args.profile_no_results, + "profile.chop": args.profile_chop, + "profile.serializer": serializer, + "profile.indexOps": args.profile_indexOps, + } try: profile_misc_args_dict = json.loads(args.profile_misc_args) profile_args.update(profile_misc_args_dict) except JSONDecodeError: - print('--profile-misc-args received invalid input, please check that you are passing in a valid ' - 'string representation of a map, ex. 
"{\'profile.x\':\'true\'}"') + print( + "--profile-misc-args received invalid input, please check that you are passing in a valid " + "string representation of a map, ex. \"{'profile.x':'true'}\"" + ) res = self.client.gremlin_profile(query=cell, args=profile_args) res.raise_for_status() - profile_bytes = res.content.replace(b'\xcc', b'-') - profile_bytes = profile_bytes.replace(b'\xb6', b'') - query_res = profile_bytes.decode('utf-8') + profile_bytes = res.content.replace(b"\xcc", b"-") + profile_bytes = profile_bytes.replace(b"\xb6", b"") + query_res = profile_bytes.decode("utf-8") if not args.silent: - gremlin_metadata = build_gremlin_metadata_from_query(query_type='profile', results=query_res, res=res) - titles.append('Profile') - if 'Neptune Gremlin Profile' in query_res: - explain_bytes = query_res.encode('ascii', 'ignore') - base64_str = base64.b64encode(explain_bytes).decode('ascii') - first_tab_html = gremlin_explain_profile_template.render(content=query_res, - link=f"data:text/html;base64,{base64_str}") + gremlin_metadata = build_gremlin_metadata_from_query( + query_type="profile", results=query_res, res=res + ) + titles.append("Profile") + if "Neptune Gremlin Profile" in query_res: + explain_bytes = query_res.encode("ascii", "ignore") + base64_str = base64.b64encode(explain_bytes).decode("ascii") + first_tab_html = gremlin_explain_profile_template.render( + content=query_res, link=f"data:text/html;base64,{base64_str}" + ) else: - first_tab_html = pre_container_template.render(content='No profile found') + first_tab_html = pre_container_template.render( + content="No profile found" + ) else: - query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms + query_start = ( + time.time() * 1000 + ) # time.time() returns time in seconds w/high precision; x1000 to get in ms query_res = self.client.gremlin_query(cell, transport_args=transport_args) query_time = time.time() * 1000 - query_start if not args.silent: - gremlin_metadata = build_gremlin_metadata_from_query(query_type='query', results=query_res, - query_time=query_time) - titles.append('Console') + gremlin_metadata = build_gremlin_metadata_from_query( + query_type="query", results=query_res, query_time=query_time + ) + titles.append("Console") gremlin_network = None try: - logger.debug(f'groupby: {args.group_by}') - logger.debug(f'display_property: {args.display_property}') - logger.debug(f'edge_display_property: {args.edge_display_property}') - logger.debug(f'label_max_length: {args.label_max_length}') - logger.debug(f'ignore_groups: {args.ignore_groups}') - gn = GremlinNetwork(group_by_property=args.group_by, display_property=args.display_property, - group_by_raw=args.group_by_raw, - group_by_depth=args.group_by_depth, - edge_display_property=args.edge_display_property, - tooltip_property=args.tooltip_property, - edge_tooltip_property=args.edge_tooltip_property, - label_max_length=args.label_max_length, - edge_label_max_length=args.edge_label_max_length, - ignore_groups=args.ignore_groups) - - if args.path_pattern == '': + logger.debug(f"groupby: {args.group_by}") + logger.debug(f"display_property: {args.display_property}") + logger.debug(f"edge_display_property: {args.edge_display_property}") + logger.debug(f"label_max_length: {args.label_max_length}") + logger.debug(f"ignore_groups: {args.ignore_groups}") + gn = GremlinNetwork( + group_by_property=args.group_by, + display_property=args.display_property, + group_by_raw=args.group_by_raw, + group_by_depth=args.group_by_depth, + 
edge_display_property=args.edge_display_property, + tooltip_property=args.tooltip_property, + edge_tooltip_property=args.edge_tooltip_property, + label_max_length=args.label_max_length, + edge_label_max_length=args.edge_label_max_length, + ignore_groups=args.ignore_groups, + ) + + if args.path_pattern == "": gn.add_results(query_res) else: pattern = parse_pattern_list_str(args.path_pattern) gn.add_results_with_pattern(query_res, pattern) gremlin_network = gn - logger.debug(f'number of nodes is {len(gn.graph.nodes)}') + logger.debug(f"number of nodes is {len(gn.graph.nodes)}") except ValueError as value_error: logger.debug( - f'Unable to create graph network from result due to error: {value_error}. ' - f'Skipping from result set.') + f"Unable to create graph network from result due to error: {value_error}. " + f"Skipping from result set." + ) if gremlin_network and len(gremlin_network.graph.nodes) > 0: try: - self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ - = args.stop_physics - self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration - f = Force(network=gremlin_network, options=self.graph_notebook_vis_options) - titles.append('Graph') + self.graph_notebook_vis_options["physics"][ + "disablePhysicsAfterInitialSimulation" + ] = args.stop_physics + self.graph_notebook_vis_options["physics"][ + "simulationDuration" + ] = args.simulation_duration + f = Force( + network=gremlin_network, + options=self.graph_notebook_vis_options, + ) + titles.append("Graph") children.append(f) - logger.debug('added gremlin network to tabs') + logger.debug("added gremlin network to tabs") except Exception as force_error: logger.debug( - f'Unable to render visualization from graph network due to error: {force_error}. Skipping.') + f"Unable to render visualization from graph network due to error: {force_error}. Skipping." + ) # Check if we can access the CDNs required by itables library. # If not, then render our own HTML template. 
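
For reference, the grouping, labelling, and physics flags registered above map one-to-one onto `%%gremlin` cell options. A minimal illustrative cell (the query and the `code` property come from the air-routes sample data, not from this patch):

```
%%gremlin -g code -d code -l 20 -sd 2000 --no-scroll
g.V().hasLabel('airport').limit(5).outE().inV().path()
```

Here `-g`/`-d` group and label vertices by their `code` property, `-l 20` raises the label truncation limit from the default of 10, and `-sd 2000` extends the physics simulation beyond the default of 1500ms before the layout freezes.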
@@ -996,7 +1486,7 @@ def gremlin(self, line, cell, local_ns: dict = None): if not all(isinstance(x, dict) for x in query_res[1:]): mixed_results = True query_res_deque = deque(query_res) - query_res_deque.appendleft('x') + query_res_deque.appendleft("x") query_res = list(query_res_deque) results_df = pd.DataFrame(query_res) @@ -1007,23 +1497,25 @@ def gremlin(self, line, cell, local_ns: dict = None): for result in query_res: fixed_result = replace_html_chars(result) query_res_reformat.append([fixed_result]) - query_res_reformat.append([{'__DUMMY_KEY__': ['DUMMY_VALUE']}]) + query_res_reformat.append([{"__DUMMY_KEY__": ["DUMMY_VALUE"]}]) results_df = pd.DataFrame(query_res_reformat) if mixed_results: results_df = results_df[1:] results_df.drop(results_df.index[-1], inplace=True) results_df.insert(0, "#", range(1, len(results_df) + 1)) if len(results_df.columns) == 2 and int(results_df.columns[1]) == 0: - results_df.rename({results_df.columns[1]: 'Result'}, axis='columns', inplace=True) + results_df.rename( + {results_df.columns[1]: "Result"}, axis="columns", inplace=True + ) else: results_df.insert(1, "Result", []) - results_df.set_index('#', inplace=True) + results_df.set_index("#", inplace=True) results_df.columns.name = results_df.index.name results_df.index.name = None if not args.silent: metadata_output = widgets.Output(layout=gremlin_layout) - titles.append('Query Metadata') + titles.append("Query Metadata") children.append(metadata_output) tab.children = children @@ -1036,8 +1528,11 @@ def gremlin(self, line, cell, local_ns: dict = None): with first_tab_output: if mode == QueryMode.DEFAULT: - visible_results, final_pagination_options, final_pagination_menu = generate_pagination_vars( - args.results_per_page) + ( + visible_results, + final_pagination_options, + final_pagination_menu, + ) = generate_pagination_vars(args.results_per_page) gremlin_columndefs = [ {"width": "5%", "targets": 0}, {"visible": True, "targets": 0}, @@ -1045,19 +1540,23 @@ def gremlin(self, line, cell, local_ns: dict = None): {"minWidth": "95%", "targets": 1}, {"className": "nowrap dt-left", "targets": "_all"}, {"createdCell": JavascriptFunction(index_col_js), "targets": 0}, - {"createdCell": JavascriptFunction(cell_style_js), "targets": "_all"}, + { + "createdCell": JavascriptFunction(cell_style_js), + "targets": "_all", + }, ] if args.hide_index: gremlin_columndefs[1]["visible"] = False - show(results_df, - scrollX=True, - scrollY=gremlin_scrollY, - columnDefs=gremlin_columndefs, - paging=gremlin_paging, - scrollCollapse=gremlin_scrollCollapse, - lengthMenu=[final_pagination_options, final_pagination_menu], - pageLength=visible_results - ) + show( + results_df, + scrollX=True, + scrollY=gremlin_scrollY, + columnDefs=gremlin_columndefs, + paging=gremlin_paging, + scrollCollapse=gremlin_scrollCollapse, + lengthMenu=[final_pagination_options, final_pagination_menu], + pageLength=visible_results, + ) else: # Explain/Profile display(HTML(first_tab_html)) @@ -1066,26 +1565,47 @@ def gremlin(self, line, cell, local_ns: dict = None): @line_magic @needs_local_scope @display_exceptions - def gremlin_status(self, line='', local_ns: dict = None): + def gremlin_status(self, line="", local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('-q', '--queryId', default='', - help='The ID of a running Gremlin query. Only displays the status of the specified query.') - parser.add_argument('-c', '--cancelQuery', action='store_true', - help='Required for cancellation. 
Parameter has no corresponding value.') - parser.add_argument('-w', '--includeWaiting', action='store_true', - help='(Optional) Normally, only running queries are included in the response. ' - 'When the includeWaiting parameter is specified, ' - 'the status of all waiting queries is also returned.') - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') + parser.add_argument( + "-q", + "--queryId", + default="", + help="The ID of a running Gremlin query. Only displays the status of the specified query.", + ) + parser.add_argument( + "-c", + "--cancelQuery", + action="store_true", + help="Required for cancellation. Parameter has no corresponding value.", + ) + parser.add_argument( + "-w", + "--includeWaiting", + action="store_true", + help="(Optional) Normally, only running queries are included in the response. " + "When the includeWaiting parameter is specified, " + "the status of all waiting queries is also returned.", + ) + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." + ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) args = parser.parse_args(line.split()) if not args.cancelQuery: - status_res = self.client.gremlin_status(query_id=args.queryId, include_waiting=args.includeWaiting) + status_res = self.client.gremlin_status( + query_id=args.queryId, include_waiting=args.includeWaiting + ) status_res.raise_for_status() res = status_res.json() else: - if args.queryId == '': + if args.queryId == "": if not args.silent: print(GREMLIN_CANCEL_HINT_MSG) return @@ -1101,55 +1621,66 @@ def gremlin_status(self, line='', local_ns: dict = None): @cell_magic @needs_local_scope @display_exceptions - def oc(self, line='', cell='', local_ns: dict = None): + def oc(self, line="", cell="", local_ns: dict = None): self.handle_opencypher_query(line, cell, local_ns) @magic_variables @cell_magic @needs_local_scope @display_exceptions - def opencypher(self, line='', cell='', local_ns: dict = None): + def opencypher(self, line="", cell="", local_ns: dict = None): self.handle_opencypher_query(line, cell, local_ns) @line_magic @needs_local_scope @display_exceptions - def oc_status(self, line='', local_ns: dict = None): + def oc_status(self, line="", local_ns: dict = None): self.handle_opencypher_status(line, local_ns) @line_magic @needs_local_scope @display_exceptions - def opencypher_status(self, line='', local_ns: dict = None): + def opencypher_status(self, line="", local_ns: dict = None): self.handle_opencypher_status(line, local_ns) @line_magic @needs_local_scope @display_exceptions - def status(self, line='', local_ns: dict = None): - logger.info(f'calling for status on endpoint {self.graph_notebook_config.host}') + def status(self, line="", local_ns: dict = None): + logger.info(f"calling for status on endpoint {self.graph_notebook_config.host}") parser = argparse.ArgumentParser() - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." 
+ ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) args = parser.parse_args(line.split()) status_res = self.client.status() status_res.raise_for_status() try: res = status_res.json() - logger.info(f'got the json format response {res}') + logger.info(f"got the json format response {res}") store_to_ns(args.store_to, res, local_ns) if not args.silent: return res except ValueError: - logger.info(f'got the HTML format response {status_res.text}') + logger.info(f"got the HTML format response {status_res.text}") store_to_ns(args.store_to, status_res.text, local_ns) if not args.silent: if "blazegraph™ by SYSTAP" in status_res.text: - print("For more information on the status of your Blazegraph cluster, please visit: ") + print( + "For more information on the status of your Blazegraph cluster, please visit: " + ) print() - print(f'http://{self.graph_notebook_config.host}:{self.graph_notebook_config.port}' - f'/blazegraph/#status') + print( + f"http://{self.graph_notebook_config.host}:{self.graph_notebook_config.port}" + f"/blazegraph/#status" + ) print() return status_res @@ -1157,41 +1688,60 @@ def status(self, line='', local_ns: dict = None): @needs_local_scope @display_exceptions def db_reset(self, line, local_ns: dict = None): - logger.info(f'calling system endpoint {self.client.host}') + logger.info(f"calling system endpoint {self.client.host}") parser = argparse.ArgumentParser() - parser.add_argument('-g', '--generate-token', action='store_true', help='generate token for database reset') - parser.add_argument('-t', '--token', default='', help='perform database reset with given token') - parser.add_argument('-y', '--yes', action='store_true', help='skip the prompt and perform database reset') - parser.add_argument('-m', '--max-status-retries', type=int, default=10, - help='Specifies how many times we should attempt to check if the database reset has ' - 'completed, in intervals of 5 seconds. Default is 10') + parser.add_argument( + "-g", + "--generate-token", + action="store_true", + help="generate token for database reset", + ) + parser.add_argument( + "-t", "--token", default="", help="perform database reset with given token" + ) + parser.add_argument( + "-y", + "--yes", + action="store_true", + help="skip the prompt and perform database reset", + ) + parser.add_argument( + "-m", + "--max-status-retries", + type=int, + default=10, + help="Specifies how many times we should attempt to check if the database reset has " + "completed, in intervals of 5 seconds. Default is 10", + ) args = parser.parse_args(line.split()) generate_token = args.generate_token skip_prompt = args.yes - max_status_retries = args.max_status_retries if args.max_status_retries > 0 else 1 - if generate_token is False and args.token == '': + max_status_retries = ( + args.max_status_retries if args.max_status_retries > 0 else 1 + ) + if generate_token is False and args.token == "": if skip_prompt: initiate_res = self.client.initiate_reset() initiate_res.raise_for_status() res = initiate_res.json() - token = res['payload']['token'] + token = res["payload"]["token"] perform_reset_res = self.client.perform_reset(token) perform_reset_res.raise_for_status() - logger.info(f'got the response {res}') + logger.info(f"got the response {res}") res = perform_reset_res.json() return res output = widgets.Output() - source = 'Are you sure you want to delete all the data in your cluster?' + source = "Are you sure you want to delete all the data in your cluster?" 
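+        # Interactive path: build an ipywidgets confirmation dialog below. The
+        # non-interactive equivalents handled above are `%db_reset --yes`
+        # (single step) or `%db_reset --generate-token` followed by
+        # `%db_reset --token <token>` (two steps).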
label = widgets.Label(source) text_hbox = widgets.HBox([label]) check_box = widgets.Checkbox( value=False, disabled=False, indent=False, - description='I acknowledge that upon deletion the cluster data will no longer be available.', - layout=widgets.Layout(width='600px', margin='5px 5px 5px 5px') + description="I acknowledge that upon deletion the cluster data will no longer be available.", + layout=widgets.Layout(width="600px", margin="5px 5px 5px 5px"), ) button_delete = widgets.Button(description="Delete") button_cancel = widgets.Button(description="Cancel") @@ -1212,12 +1762,12 @@ def on_button_delete_clicked(b): if not check_box.value: with output: - print('Checkbox is not checked.') + print("Checkbox is not checked.") return - token = result['payload']['token'] + token = result["payload"]["token"] if token == "": with output: - print('Failed to get token.') + print("Failed to get token.") print(result) return @@ -1225,9 +1775,11 @@ def on_button_delete_clicked(b): perform_reset_res.raise_for_status() result = perform_reset_res.json() - if 'status' not in result or result['status'] != '200 OK': + if "status" not in result or result["status"] != "200 OK": with output: - print('Database reset failed, please try the operation again or reboot the cluster.') + print( + "Database reset failed, please try the operation again or reboot the cluster." + ) print(result) logger.error(result) return @@ -1250,7 +1802,7 @@ def on_button_delete_clicked(b): interval_output.clear_output() if time_elapsed > poll_interval: with interval_output: - print('checking status...') + print("checking status...") job_status_output.clear_output() new_interval = True try: @@ -1266,9 +1818,9 @@ def on_button_delete_clicked(b): continue job_status_output.clear_output() with job_status_output: - if interval_check_response["status"] == 'healthy': + if interval_check_response["status"] == "healthy": interval_output.close() - print('Database has been reset.') + print("Database has been reset.") return last_poll_time = time.time() else: @@ -1277,17 +1829,19 @@ def on_button_delete_clicked(b): display_html(HTML(loading_wheel_html)) new_interval = False with interval_output: - print(f'checking status in {time_remaining} seconds') + print(f"checking status in {time_remaining} seconds") time.sleep(1) with output: job_status_output.clear_output() interval_output.close() - total_status_wait = max_status_retries*poll_interval + total_status_wait = max_status_retries * poll_interval print(result) - if interval_check_response.get("status") != 'healthy': - print(f"Could not retrieve the status of the reset operation within the allotted time of " - f"{total_status_wait} seconds. If the database is not in healthy status after at least 1 " - f"minute, please try the operation again or reboot the cluster.") + if interval_check_response.get("status") != "healthy": + print( + f"Could not retrieve the status of the reset operation within the allotted time of " + f"{total_status_wait} seconds. If the database is not in healthy status after at least 1 " + f"minute, please try the operation again or reboot the cluster." 
+ ) def on_button_cancel_clicked(b): text_hbox.close() @@ -1296,7 +1850,7 @@ def on_button_cancel_clicked(b): button_cancel.close() button_hbox.close() with output: - print('Database reset operation has been canceled.') + print("Database reset operation has been canceled.") button_delete.on_click(on_button_delete_clicked) button_cancel.on_click(on_button_cancel_clicked) @@ -1311,175 +1865,208 @@ def on_button_cancel_clicked(b): perform_res.raise_for_status() res = perform_res.json() - logger.info(f'got the response {res}') + logger.info(f"got the response {res}") return res @line_magic @needs_local_scope @display_exceptions - def load(self, line='', local_ns: dict = None): + def load(self, line="", local_ns: dict = None): # TODO: change widgets to let any arbitrary inputs be added by users parser = argparse.ArgumentParser() - parser.add_argument('-s', '--source', default='s3://') + parser.add_argument("-s", "--source", default="s3://") try: - parser.add_argument('-l', '--loader-arn', default=self.graph_notebook_config.load_from_s3_arn) + parser.add_argument( + "-l", + "--loader-arn", + default=self.graph_notebook_config.load_from_s3_arn, + ) except AttributeError: - print(f"Missing required configuration option 'load_from_s3_arn'. Please ensure that you have provided a " - "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.") + print( + f"Missing required configuration option 'load_from_s3_arn'. Please ensure that you have provided a " + "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config." + ) return - parser.add_argument('-f', '--format', choices=LOADER_FORMAT_CHOICES, default=FORMAT_CSV) - parser.add_argument('-p', '--parallelism', choices=PARALLELISM_OPTIONS, default=PARALLELISM_HIGH) + parser.add_argument( + "-f", "--format", choices=LOADER_FORMAT_CHOICES, default=FORMAT_CSV + ) + parser.add_argument( + "-p", "--parallelism", choices=PARALLELISM_OPTIONS, default=PARALLELISM_HIGH + ) try: - parser.add_argument('-r', '--region', default=self.graph_notebook_config.aws_region) + parser.add_argument( + "-r", "--region", default=self.graph_notebook_config.aws_region + ) except AttributeError: - print("Missing required configuration option 'aws_region'. Please ensure that you have provided a " - "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config.") + print( + "Missing required configuration option 'aws_region'. Please ensure that you have provided a " + "valid Neptune cluster endpoint URI in the 'host' field of %graph_notebook_config." + ) return - parser.add_argument('--fail-on-failure', action='store_true', default=False) - parser.add_argument('--update-single-cardinality', action='store_true', default=True) - parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') - parser.add_argument('--run', action='store_true', default=False) - parser.add_argument('-m', '--mode', choices=LOAD_JOB_MODES, default=MODE_AUTO) - parser.add_argument('-q', '--queue-request', action='store_true', default=False) - parser.add_argument('-d', '--dependencies', action='append', default=[]) - parser.add_argument('-e', '--no-edge-ids', action='store_true', default=False) - parser.add_argument('--named-graph-uri', type=str, default=DEFAULT_NAMEDGRAPH_URI, - help='The default graph for all RDF formats when no graph is specified. 
' - 'Default is http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph.') - parser.add_argument('--base-uri', type=str, default=DEFAULT_BASE_URI, - help='The base URI for RDF/XML and Turtle formats. ' - 'Default is http://aws.amazon.com/neptune/default') - parser.add_argument('--allow-empty-strings', action='store_true', default=False, - help='Load empty strings found in node and edge property values.') - parser.add_argument('-n', '--nopoll', action='store_true', default=False) + parser.add_argument("--fail-on-failure", action="store_true", default=False) + parser.add_argument( + "--update-single-cardinality", action="store_true", default=True + ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) + parser.add_argument("--run", action="store_true", default=False) + parser.add_argument("-m", "--mode", choices=LOAD_JOB_MODES, default=MODE_AUTO) + parser.add_argument("-q", "--queue-request", action="store_true", default=False) + parser.add_argument("-d", "--dependencies", action="append", default=[]) + parser.add_argument("-e", "--no-edge-ids", action="store_true", default=False) + parser.add_argument( + "--named-graph-uri", + type=str, + default=DEFAULT_NAMEDGRAPH_URI, + help="The default graph for all RDF formats when no graph is specified. " + "Default is http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph.", + ) + parser.add_argument( + "--base-uri", + type=str, + default=DEFAULT_BASE_URI, + help="The base URI for RDF/XML and Turtle formats. " + "Default is http://aws.amazon.com/neptune/default", + ) + parser.add_argument( + "--allow-empty-strings", + action="store_true", + default=False, + help="Load empty strings found in node and edge property values.", + ) + parser.add_argument("-n", "--nopoll", action="store_true", default=False) args = parser.parse_args(line.split()) button = widgets.Button(description="Submit") output = widgets.Output() - widget_width = '25%' - label_width = '16%' + widget_width = "25%" + label_width = "16%" source = widgets.Text( value=args.source, - placeholder='Type something', + placeholder="Type something", disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) arn = widgets.Text( value=args.loader_arn, - placeholder='Type something', + placeholder="Type something", disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) source_format = widgets.Dropdown( options=LOADER_FORMAT_CHOICES, value=args.format, disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) - ids_hbox_visibility = 'none' - gremlin_parser_options_hbox_visibility = 'none' - named_graph_hbox_visibility = 'none' - base_uri_hbox_visibility = 'none' + ids_hbox_visibility = "none" + gremlin_parser_options_hbox_visibility = "none" + named_graph_hbox_visibility = "none" + base_uri_hbox_visibility = "none" if source_format.value.lower() == FORMAT_CSV: - gremlin_parser_options_hbox_visibility = 'flex' + gremlin_parser_options_hbox_visibility = "flex" elif source_format.value.lower() == FORMAT_OPENCYPHER: - ids_hbox_visibility = 'flex' + ids_hbox_visibility = "flex" elif source_format.value.lower() in RDF_LOAD_FORMATS: - named_graph_hbox_visibility = 'flex' + named_graph_hbox_visibility = "flex" if source_format.value.lower() in BASE_URI_FORMATS: - base_uri_hbox_visibility = 'flex' + base_uri_hbox_visibility = "flex" region_box = widgets.Text( value=args.region, 
placeholder=args.region, disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) fail_on_error = widgets.Dropdown( - options=['TRUE', 'FALSE'], + options=["TRUE", "FALSE"], value=str(args.fail_on_failure).upper(), disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) parallelism = widgets.Dropdown( options=PARALLELISM_OPTIONS, value=args.parallelism, disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) allow_empty_strings = widgets.Dropdown( - options=['TRUE', 'FALSE'], + options=["TRUE", "FALSE"], value=str(args.allow_empty_strings).upper(), disabled=False, - layout=widgets.Layout(display=gremlin_parser_options_hbox_visibility, - width=widget_width) + layout=widgets.Layout( + display=gremlin_parser_options_hbox_visibility, width=widget_width + ), ) named_graph_uri = widgets.Text( value=args.named_graph_uri, - placeholder='http://named-graph-string', + placeholder="http://named-graph-string", disabled=False, - layout=widgets.Layout(display=named_graph_hbox_visibility, - width=widget_width) + layout=widgets.Layout( + display=named_graph_hbox_visibility, width=widget_width + ), ) base_uri = widgets.Text( value=args.base_uri, - placeholder='http://base-uri-string', + placeholder="http://base-uri-string", disabled=False, - layout=widgets.Layout(display=base_uri_hbox_visibility, - width=widget_width) + layout=widgets.Layout(display=base_uri_hbox_visibility, width=widget_width), ) update_single_cardinality = widgets.Dropdown( - options=['TRUE', 'FALSE'], + options=["TRUE", "FALSE"], value=str(args.update_single_cardinality).upper(), disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) mode = widgets.Dropdown( options=LOAD_JOB_MODES, value=args.mode, disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) user_provided_edge_ids = widgets.Dropdown( - options=['TRUE', 'FALSE'], + options=["TRUE", "FALSE"], value=str(not args.no_edge_ids).upper(), disabled=False, - layout=widgets.Layout(display=ids_hbox_visibility, - width=widget_width) + layout=widgets.Layout(display=ids_hbox_visibility, width=widget_width), ) queue_request = widgets.Dropdown( - options=['TRUE', 'FALSE'], + options=["TRUE", "FALSE"], value=str(args.queue_request).upper(), disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) dependencies = widgets.Textarea( value="\n".join(args.dependencies), - placeholder='load_A_id\nload_B_id', + placeholder="load_A_id\nload_B_id", disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) poll_status = widgets.Dropdown( - options=['TRUE', 'FALSE'], + options=["TRUE", "FALSE"], value=str(not args.nopoll).upper(), disabled=False, - layout=widgets.Layout(width=widget_width) + layout=widgets.Layout(width=widget_width), ) # Create a series of HBox containers that will hold the widgets and labels @@ -1487,201 +2074,305 @@ def load(self, line='', local_ns: dict = None): # in two parts to support the validation steps that come later. In the case # of validation errors this allows additional text to easily be added to an # HBox describing the issue. 
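         # As a sketch of the pattern (widget names abbreviated), each row is:
         #     lbl = widgets.Label("Source:", layout=widgets.Layout(
         #         width=label_width, display="flex", justify_content="flex-end"))
         #     row = widgets.HBox([lbl, source])
         # so that, on validation errors, an extra message widget can simply be
         # appended to the row's children.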
- source_hbox_label = widgets.Label('Source:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")) + source_hbox_label = widgets.Label( + "Source:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ) source_hbox = widgets.HBox([source_hbox_label, source]) - format_hbox_label = widgets.Label('Format:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")) + format_hbox_label = widgets.Label( + "Format:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ) source_format_hbox = widgets.HBox([format_hbox_label, source_format]) - region_hbox = widgets.HBox([widgets.Label('Region:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")), - region_box]) + region_hbox = widgets.HBox( + [ + widgets.Label( + "Region:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ), + region_box, + ] + ) - arn_hbox_label = widgets.Label('Load ARN:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")) + arn_hbox_label = widgets.Label( + "Load ARN:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ) arn_hbox = widgets.HBox([arn_hbox_label, arn]) - mode_hbox = widgets.HBox([widgets.Label('Mode:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")), - mode]) + mode_hbox = widgets.HBox( + [ + widgets.Label( + "Mode:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ), + mode, + ] + ) - fail_hbox = widgets.HBox([widgets.Label('Fail on Error:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")), - fail_on_error]) + fail_hbox = widgets.HBox( + [ + widgets.Label( + "Fail on Error:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ), + fail_on_error, + ] + ) - parallelism_hbox = widgets.HBox([widgets.Label('Parallelism:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")), - parallelism]) + parallelism_hbox = widgets.HBox( + [ + widgets.Label( + "Parallelism:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ), + parallelism, + ] + ) - allow_empty_strings_hbox_label = widgets.Label('Allow Empty Strings:', - layout=widgets.Layout(width=label_width, - display=gremlin_parser_options_hbox_visibility, - justify_content="flex-end")) + allow_empty_strings_hbox_label = widgets.Label( + "Allow Empty Strings:", + layout=widgets.Layout( + width=label_width, + display=gremlin_parser_options_hbox_visibility, + justify_content="flex-end", + ), + ) - allow_empty_strings_hbox = widgets.HBox([allow_empty_strings_hbox_label, allow_empty_strings]) + allow_empty_strings_hbox = widgets.HBox( + [allow_empty_strings_hbox_label, allow_empty_strings] + ) - named_graph_uri_hbox_label = widgets.Label('Named Graph URI:', - layout=widgets.Layout(width=label_width, - display=named_graph_hbox_visibility, - justify_content="flex-end")) + named_graph_uri_hbox_label = widgets.Label( + "Named Graph URI:", + layout=widgets.Layout( + width=label_width, + display=named_graph_hbox_visibility, + justify_content="flex-end", + ), + ) - named_graph_uri_hbox = widgets.HBox([named_graph_uri_hbox_label, named_graph_uri]) + named_graph_uri_hbox = widgets.HBox( + 
[named_graph_uri_hbox_label, named_graph_uri] + ) - base_uri_hbox_label = widgets.Label('Base URI:', - layout=widgets.Layout(width=label_width, - display=base_uri_hbox_visibility, - justify_content="flex-end")) + base_uri_hbox_label = widgets.Label( + "Base URI:", + layout=widgets.Layout( + width=label_width, + display=base_uri_hbox_visibility, + justify_content="flex-end", + ), + ) base_uri_hbox = widgets.HBox([base_uri_hbox_label, base_uri]) - cardinality_hbox = widgets.HBox([widgets.Label('Update Single Cardinality:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")), - update_single_cardinality]) + cardinality_hbox = widgets.HBox( + [ + widgets.Label( + "Update Single Cardinality:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ), + update_single_cardinality, + ] + ) - queue_hbox = widgets.HBox([widgets.Label('Queue Request:', - layout=widgets.Layout(width=label_width, - display="flex", justify_content="flex-end")), - queue_request]) + queue_hbox = widgets.HBox( + [ + widgets.Label( + "Queue Request:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ), + queue_request, + ] + ) - dep_hbox_label = widgets.Label('Dependencies:', - layout=widgets.Layout(width=label_width, - display="flex", justify_content="flex-end")) + dep_hbox_label = widgets.Label( + "Dependencies:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ) dep_hbox = widgets.HBox([dep_hbox_label, dependencies]) - ids_hbox_label = widgets.Label('User Provided Edge Ids:', - layout=widgets.Layout(width=label_width, - display=ids_hbox_visibility, - justify_content="flex-end")) + ids_hbox_label = widgets.Label( + "User Provided Edge Ids:", + layout=widgets.Layout( + width=label_width, + display=ids_hbox_visibility, + justify_content="flex-end", + ), + ) ids_hbox = widgets.HBox([ids_hbox_label, user_provided_edge_ids]) - poll_status_label = widgets.Label('Poll Load Status:', - layout=widgets.Layout(width=label_width, - display="flex", - justify_content="flex-end")) + poll_status_label = widgets.Label( + "Poll Load Status:", + layout=widgets.Layout( + width=label_width, display="flex", justify_content="flex-end" + ), + ) poll_status_hbox = widgets.HBox([poll_status_label, poll_status]) def update_edge_ids_options(change): if change.new.lower() == FORMAT_OPENCYPHER: - ids_hbox_visibility = 'flex' + ids_hbox_visibility = "flex" else: - ids_hbox_visibility = 'none' - user_provided_edge_ids.value = 'TRUE' + ids_hbox_visibility = "none" + user_provided_edge_ids.value = "TRUE" user_provided_edge_ids.layout.display = ids_hbox_visibility ids_hbox_label.layout.display = ids_hbox_visibility def update_parserconfig_options(change): if change.new.lower() == FORMAT_CSV: - gremlin_parser_options_hbox_visibility = 'flex' - named_graph_hbox_visibility_hbox_visibility = 'none' - base_uri_hbox_visibility = 'none' - named_graph_uri.value = '' - base_uri.value = '' + gremlin_parser_options_hbox_visibility = "flex" + named_graph_hbox_visibility_hbox_visibility = "none" + base_uri_hbox_visibility = "none" + named_graph_uri.value = "" + base_uri.value = "" elif change.new.lower() == FORMAT_OPENCYPHER: - gremlin_parser_options_hbox_visibility = 'none' - allow_empty_strings.value = 'FALSE' - named_graph_hbox_visibility_hbox_visibility = 'none' - base_uri_hbox_visibility = 'none' - named_graph_uri.value = '' - base_uri.value = '' + gremlin_parser_options_hbox_visibility = 
"none" + allow_empty_strings.value = "FALSE" + named_graph_hbox_visibility_hbox_visibility = "none" + base_uri_hbox_visibility = "none" + named_graph_uri.value = "" + base_uri.value = "" else: - gremlin_parser_options_hbox_visibility = 'none' - allow_empty_strings.value = 'FALSE' - named_graph_hbox_visibility_hbox_visibility = 'flex' + gremlin_parser_options_hbox_visibility = "none" + allow_empty_strings.value = "FALSE" + named_graph_hbox_visibility_hbox_visibility = "flex" named_graph_uri.value = DEFAULT_NAMEDGRAPH_URI if change.new.lower() in BASE_URI_FORMATS: - base_uri_hbox_visibility = 'flex' + base_uri_hbox_visibility = "flex" base_uri.value = DEFAULT_BASE_URI else: - base_uri_hbox_visibility = 'none' - base_uri.value = '' + base_uri_hbox_visibility = "none" + base_uri.value = "" allow_empty_strings.layout.display = gremlin_parser_options_hbox_visibility - allow_empty_strings_hbox_label.layout.display = gremlin_parser_options_hbox_visibility + allow_empty_strings_hbox_label.layout.display = ( + gremlin_parser_options_hbox_visibility + ) named_graph_uri.layout.display = named_graph_hbox_visibility_hbox_visibility - named_graph_uri_hbox_label.layout.display = named_graph_hbox_visibility_hbox_visibility + named_graph_uri_hbox_label.layout.display = ( + named_graph_hbox_visibility_hbox_visibility + ) base_uri.layout.display = base_uri_hbox_visibility base_uri_hbox_label.layout.display = base_uri_hbox_visibility - source_format.observe(update_edge_ids_options, names='value') - source_format.observe(update_parserconfig_options, names='value') - - display(source_hbox, - source_format_hbox, - region_hbox, - arn_hbox, - mode_hbox, - fail_hbox, - parallelism_hbox, - cardinality_hbox, - queue_hbox, - dep_hbox, - poll_status_hbox, - ids_hbox, - allow_empty_strings_hbox, - named_graph_uri_hbox, - base_uri_hbox, - button, - output) + source_format.observe(update_edge_ids_options, names="value") + source_format.observe(update_parserconfig_options, names="value") + + display( + source_hbox, + source_format_hbox, + region_hbox, + arn_hbox, + mode_hbox, + fail_hbox, + parallelism_hbox, + cardinality_hbox, + queue_hbox, + dep_hbox, + poll_status_hbox, + ids_hbox, + allow_empty_strings_hbox, + named_graph_uri_hbox, + base_uri_hbox, + button, + output, + ) def on_button_clicked(b): - source_hbox.children = (source_hbox_label, source,) - arn_hbox.children = (arn_hbox_label, arn,) - source_format_hbox.children = (format_hbox_label, source_format,) - allow_empty_strings.children = (allow_empty_strings_hbox_label, allow_empty_strings,) - named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,) - base_uri_hbox.children = (base_uri_hbox_label, base_uri,) - dep_hbox.children = (dep_hbox_label, dependencies,) + source_hbox.children = ( + source_hbox_label, + source, + ) + arn_hbox.children = ( + arn_hbox_label, + arn, + ) + source_format_hbox.children = ( + format_hbox_label, + source_format, + ) + allow_empty_strings.children = ( + allow_empty_strings_hbox_label, + allow_empty_strings, + ) + named_graph_uri_hbox.children = ( + named_graph_uri_hbox_label, + named_graph_uri, + ) + base_uri_hbox.children = ( + base_uri_hbox_label, + base_uri, + ) + dep_hbox.children = ( + dep_hbox_label, + dependencies, + ) - dependencies_list = list(filter(None, dependencies.value.split('\n'))) + dependencies_list = list(filter(None, dependencies.value.split("\n"))) validated = True - validation_label_style = DescriptionStyle(color='red') - if not (source.value.startswith('s3://') and len(source.value) > 7) 
and not source.value.startswith('/'):
+            validation_label_style = DescriptionStyle(color="red")
+            if not (
+                source.value.startswith("s3://") and len(source.value) > 7
+            ) and not source.value.startswith("/"):
                 validated = False
                 source_validation_label = widgets.HTML(
-                    '<p style="color:red;">Source must be an s3 bucket or file path</p>')
+                    '<p style="color:red;">Source must be an s3 bucket or file path</p>'
+                )
                 source_validation_label.style = validation_label_style
                 source_hbox.children += (source_validation_label,)

-            if source_format.value == '':
+            if source_format.value == "":
                 validated = False
-                source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
+                source_format_validation_label = widgets.HTML(
+                    '<p style="color:red;">Format cannot be blank.</p>'
+                )
                 source_format_hbox.children += (source_format_validation_label,)

-            if not arn.value.startswith('arn:aws') and source.value.startswith(
-                    "s3://"):  # only do this validation if we are using an s3 bucket.
+            if not arn.value.startswith("arn:aws") and source.value.startswith(
+                "s3://"
+            ):  # only do this validation if we are using an s3 bucket.
                 validated = False
-                arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
+                arn_validation_label = widgets.HTML(
+                    '<p style="color:red;">Load ARN must start with "arn:aws"</p>'
+                )
                 arn_hbox.children += (arn_validation_label,)

             if not len(dependencies_list) < 64:
                 validated = False
                 dep_validation_label = widgets.HTML(
-                    '<p style="color:red;">A maximum of 64 jobs may be queued at once.</p>')
+                    '<p style="color:red;">A maximum of 64 jobs may be queued at once.</p>
' + ) dep_hbox.children += (dep_validation_label,) if not validated: @@ -1689,32 +2380,36 @@ def on_button_clicked(b): # replace any env variables in source.value with their values, can use $foo or ${foo}. # Particularly useful for ${AWS_REGION} - source_exp = os.path.expandvars( - source.value) - logger.info(f'using source_exp: {source_exp}') + source_exp = os.path.expandvars(source.value) + logger.info(f"using source_exp: {source_exp}") try: kwargs = { - 'failOnError': fail_on_error.value, - 'parallelism': parallelism.value, - 'updateSingleCardinalityProperties': update_single_cardinality.value, - 'queueRequest': queue_request.value, - 'region': region_box.value, - 'parserConfiguration': {} + "failOnError": fail_on_error.value, + "parallelism": parallelism.value, + "updateSingleCardinalityProperties": update_single_cardinality.value, + "queueRequest": queue_request.value, + "region": region_box.value, + "parserConfiguration": {}, } if dependencies: - kwargs['dependencies'] = dependencies_list + kwargs["dependencies"] = dependencies_list if source_format.value.lower() == FORMAT_OPENCYPHER: - kwargs['userProvidedEdgeIds'] = user_provided_edge_ids.value + kwargs["userProvidedEdgeIds"] = user_provided_edge_ids.value elif source_format.value.lower() == FORMAT_CSV: - if allow_empty_strings.value == 'TRUE': - kwargs['parserConfiguration']['allowEmptyStrings'] = True + if allow_empty_strings.value == "TRUE": + kwargs["parserConfiguration"]["allowEmptyStrings"] = True elif source_format.value.lower() in RDF_LOAD_FORMATS: if named_graph_uri.value: - kwargs['parserConfiguration']['namedGraphUri'] = named_graph_uri.value - if base_uri.value and source_format.value.lower() in BASE_URI_FORMATS: - kwargs['parserConfiguration']['baseUri'] = base_uri.value + kwargs["parserConfiguration"][ + "namedGraphUri" + ] = named_graph_uri.value + if ( + base_uri.value + and source_format.value.lower() in BASE_URI_FORMATS + ): + kwargs["parserConfiguration"]["baseUri"] = base_uri.value source_hbox.close() source_format_hbox.close() @@ -1742,9 +2437,13 @@ def on_button_clicked(b): display_html(HTML(loading_wheel_html)) try: if source.value.startswith("s3://"): - load_res = self.client.load(str(source_exp), source_format.value, arn.value, **kwargs) + load_res = self.client.load( + str(source_exp), source_format.value, arn.value, **kwargs + ) else: - load_res = self.client.load(str(source_exp), source_format.value, **kwargs) + load_res = self.client.load( + str(source_exp), source_format.value, **kwargs + ) except Exception as e: load_submit_status_output.clear_output() with output: @@ -1760,16 +2459,18 @@ def on_button_clicked(b): load_result = load_res.json() store_to_ns(args.store_to, load_result, local_ns) - if 'status' not in load_result or load_result['status'] != '200 OK': + if "status" not in load_result or load_result["status"] != "200 OK": with output: - print('Something went wrong.') + print("Something went wrong.") logger.error(load_result) return - if poll_status.value == 'FALSE': - start_msg_label = widgets.Label(f'Load started successfully!') - polling_msg_label = widgets.Label(f'You can run "%load_status {load_result["payload"]["loadId"]}" ' - f'in another cell to check the current status of your bulk load.') + if poll_status.value == "FALSE": + start_msg_label = widgets.Label(f"Load started successfully!") + polling_msg_label = widgets.Label( + f'You can run "%load_status {load_result["payload"]["loadId"]}" ' + f"in another cell to check the current status of your bulk load." 
+ ) start_msg_hbox = widgets.HBox([start_msg_label]) polling_msg_hbox = widgets.HBox([polling_msg_label]) vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox]) @@ -1777,7 +2478,9 @@ def on_button_clicked(b): display(vbox) else: poll_interval = 5 - load_id_label = widgets.Label(f'Load ID: {load_result["payload"]["loadId"]}') + load_id_label = widgets.Label( + f'Load ID: {load_result["payload"]["loadId"]}' + ) interval_output = widgets.Output() job_status_output = widgets.Output() load_id_hbox = widgets.HBox([load_id_label]) @@ -1794,36 +2497,55 @@ def on_button_clicked(b): interval_output.clear_output() if time_elapsed > poll_interval: with interval_output: - print('checking status...') + print("checking status...") job_status_output.clear_output() with job_status_output: display_html(HTML(loading_wheel_html)) new_interval = True try: - load_status_res = self.client.load_status(load_result['payload']['loadId']) + load_status_res = self.client.load_status( + load_result["payload"]["loadId"] + ) load_status_res.raise_for_status() interval_check_response = load_status_res.json() except Exception as e: logger.error(e) with job_status_output: - print('Something went wrong updating job status. Ending.') + print( + "Something went wrong updating job status. Ending." + ) return job_status_output.clear_output() with job_status_output: - print(f'Overall Status: ' - f'{interval_check_response["payload"]["overallStatus"]["status"]}') - if interval_check_response["payload"]["overallStatus"]["status"] in FINAL_LOAD_STATUSES: - execution_time = \ - interval_check_response["payload"]["overallStatus"]["totalTimeSpent"] + print( + f"Overall Status: " + f'{interval_check_response["payload"]["overallStatus"]["status"]}' + ) + if ( + interval_check_response["payload"]["overallStatus"][ + "status" + ] + in FINAL_LOAD_STATUSES + ): + execution_time = interval_check_response["payload"][ + "overallStatus" + ]["totalTimeSpent"] if execution_time == 0: - execution_time_statement = '<1 second' + execution_time_statement = "<1 second" elif execution_time > 59: - execution_time_statement = str(datetime.timedelta(seconds=execution_time)) + execution_time_statement = str( + datetime.timedelta(seconds=execution_time) + ) else: - execution_time_statement = f'{execution_time} seconds' - print('Total execution time: ' + execution_time_statement) + execution_time_statement = ( + f"{execution_time} seconds" + ) + print( + "Total execution time: " + + execution_time_statement + ) interval_output.close() - print('Done.') + print("Done.") return last_poll_time = time.time() else: @@ -1832,13 +2554,13 @@ def on_button_clicked(b): display_html(HTML(loading_wheel_html)) new_interval = False with interval_output: - print(f'checking status in {time_remaining} seconds') + print(f"checking status in {time_remaining} seconds") time.sleep(1) except HTTPError as httpEx: output.clear_output() with output: - print(httpEx.response.content.decode('utf-8')) + print(httpEx.response.content.decode("utf-8")) button.on_click(on_button_clicked) if args.run: @@ -1849,13 +2571,23 @@ def on_button_clicked(b): @needs_local_scope def load_ids(self, line, local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('--details', action='store_true', default=False, - help="Display status details for each load job. Most recent jobs are listed first.") - parser.add_argument('--limit', type=int, default=maxsize, - help='If --details is True, only return the x most recent load job statuses. 
' - 'Defaults to sys.maxsize.') - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='') + parser.add_argument( + "--details", + action="store_true", + default=False, + help="Display status details for each load job. Most recent jobs are listed first.", + ) + parser.add_argument( + "--limit", + type=int, + default=maxsize, + help="If --details is True, only return the x most recent load job statuses. " + "Defaults to sys.maxsize.", + ) + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." + ) + parser.add_argument("--store-to", type=str, default="") args = parser.parse_args(line.split()) ids, res = get_load_ids(self.client) @@ -1889,28 +2621,38 @@ def load_ids(self, line, local_ns: dict = None): table_output = widgets.Output(layout=DEFAULT_LAYOUT) raw_output = widgets.Output(layout=DEFAULT_LAYOUT) tab.children = [table_output, raw_output] - tab.set_title(0, 'Table') - tab.set_title(1, 'Raw') + tab.set_title(0, "Table") + tab.set_title(1, "Raw") display(tab) results_df = pd.DataFrame(res_table) results_df.insert(0, "#", range(1, len(results_df) + 1)) with table_output: - show(results_df, - scrollX=True, - scrollY="475px", - columnDefs=[ - {"width": "5%", "targets": 0}, - {"className": "nowrap dt-left", "targets": "_all"}, - {"createdCell": JavascriptFunction(index_col_js), "targets": 0}, - {"createdCell": JavascriptFunction(cell_style_js), "targets": "_all"} - ], - paging=True, - scrollCollapse=True, - lengthMenu=[DEFAULT_PAGINATION_OPTIONS, DEFAULT_PAGINATION_MENU], - pageLength=10, - ) + show( + results_df, + scrollX=True, + scrollY="475px", + columnDefs=[ + {"width": "5%", "targets": 0}, + {"className": "nowrap dt-left", "targets": "_all"}, + { + "createdCell": JavascriptFunction(index_col_js), + "targets": 0, + }, + { + "createdCell": JavascriptFunction(cell_style_js), + "targets": "_all", + }, + ], + paging=True, + scrollCollapse=True, + lengthMenu=[ + DEFAULT_PAGINATION_OPTIONS, + DEFAULT_PAGINATION_MENU, + ], + pageLength=10, + ) with raw_output: print(json.dumps(res, indent=2)) @@ -1928,22 +2670,32 @@ def load_ids(self, line, local_ns: dict = None): @needs_local_scope def load_status(self, line, local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('load_id', default='', help='loader id to check status for') - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='') - parser.add_argument('--details', action='store_true', default=False) - parser.add_argument('--errors', action='store_true', default=False) - parser.add_argument('--page', '-p', default='1', - help='The error page number. Only valid when the --errors option is set.') - parser.add_argument('--errorsPerPage', '-e', default='10', - help='The number of errors per each page. Only valid when the --errors option is set.') + parser.add_argument("load_id", default="", help="loader id to check status for") + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." + ) + parser.add_argument("--store-to", type=str, default="") + parser.add_argument("--details", action="store_true", default=False) + parser.add_argument("--errors", action="store_true", default=False) + parser.add_argument( + "--page", + "-p", + default="1", + help="The error page number. 
Only valid when the --errors option is set.", + ) + parser.add_argument( + "--errorsPerPage", + "-e", + default="10", + help="The number of errors per each page. Only valid when the --errors option is set.", + ) args = parser.parse_args(line.split()) payload = { - 'details': args.details, - 'errors': args.errors, - 'page': args.page, - 'errorsPerPage': args.errorsPerPage + "details": args.details, + "errors": args.errors, + "page": args.page, + "errorsPerPage": args.errorsPerPage, } load_status_res = self.client.load_status(args.load_id, **payload) load_status_res.raise_for_status() @@ -1958,11 +2710,19 @@ def load_status(self, line, local_ns: dict = None): @needs_local_scope def cancel_load(self, line, local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('load_id', nargs="?", default='', help='loader id to check status for') - parser.add_argument('--all-in-queue', action='store_true', default=False, - help="Cancel all load jobs with LOAD_IN_QUEUE status.") - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='') + parser.add_argument( + "load_id", nargs="?", default="", help="loader id to check status for" + ) + parser.add_argument( + "--all-in-queue", + action="store_true", + default=False, + help="Cancel all load jobs with LOAD_IN_QUEUE status.", + ) + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." + ) + parser.add_argument("--store-to", type=str, default="") args = parser.parse_args(line.split()) loads_to_cancel = [] @@ -1989,10 +2749,10 @@ def cancel_load(self, line, local_ns: dict = None): res = cancel_res.json() if res: raw_res[load_id] = res - print_res[load_id] = 'Cancelled successfully.' + print_res[load_id] = "Cancelled successfully." else: - raw_res[load_id] = 'Something went wrong cancelling bulk load job.' - print_res[load_id] = 'Something went wrong cancelling bulk load job.' + raw_res[load_id] = "Something went wrong cancelling bulk load job." + print_res[load_id] = "Something went wrong cancelling bulk load job." if not args.silent: if print_res: @@ -2017,276 +2777,320 @@ def seed(self, line, local_ns: dict = None): or a directory containing multiple of these files. """ parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str.lower, default='', - help='Specifies what data model you would like to load for. ' - 'Accepted values: property_graph, rdf') - parser.add_argument('--language', type=str.lower, default='', - help='Specifies what language you would like to load for. ' - 'Accepted values: gremlin, sparql, opencypher') - parser.add_argument('--dataset', type=str, default='', - help='Specifies what sample dataset you would like to load.') - parser.add_argument('--source', type=str, default='', - help='Specifies the full path to a local file or directory that you would like to ' - 'load from.') - parser.add_argument('-f', '--full-file-query', action='store_true', default=False, - help='Read all content of a file as a single query, instead of per line') + parser.add_argument( + "--model", + type=str.lower, + default="", + help="Specifies what data model you would like to load for. " + "Accepted values: property_graph, rdf", + ) + parser.add_argument( + "--language", + type=str.lower, + default="", + help="Specifies what language you would like to load for. 
" + "Accepted values: gremlin, sparql, opencypher", + ) + parser.add_argument( + "--dataset", + type=str, + default="", + help="Specifies what sample dataset you would like to load.", + ) + parser.add_argument( + "--source", + type=str, + default="", + help="Specifies the full path to a local file or directory that you would like to " + "load from.", + ) + parser.add_argument( + "-f", + "--full-file-query", + action="store_true", + default=False, + help="Read all content of a file as a single query, instead of per line", + ) # TODO: Gremlin api paths are not yet supported. - parser.add_argument('--path', '-p', default=SPARQL_ACTION, - help='prefix path to query endpoint. For example, "foo/bar". ' - 'The queried path would then be host:port/foo/bar for sparql seed commands') - parser.add_argument('--run', action='store_true') - parser.add_argument('--ignore-errors', action='store_true', default=False, - help='Continue loading from the seed file on failure of any individual query.') + parser.add_argument( + "--path", + "-p", + default=SPARQL_ACTION, + help='prefix path to query endpoint. For example, "foo/bar". ' + "The queried path would then be host:port/foo/bar for sparql seed commands", + ) + parser.add_argument("--run", action="store_true") + parser.add_argument( + "--ignore-errors", + action="store_true", + default=False, + help="Continue loading from the seed file on failure of any individual query.", + ) args = parser.parse_args(line.split()) output = widgets.Output() progress_output = widgets.Output() source_dropdown = widgets.Dropdown( - options=SEED_SOURCE_OPTIONS, - description='Source type:', - disabled=False + options=SEED_SOURCE_OPTIONS, description="Source type:", disabled=False ) model_dropdown = widgets.Dropdown( options=SEED_MODEL_OPTIONS, - description='Data model:', + description="Data model:", disabled=False, - layout=widgets.Layout(display='none') + layout=widgets.Layout(display="none"), ) custom_language_dropdown = widgets.Dropdown( options=SEED_LANGUAGE_OPTIONS, - description='Language:', + description="Language:", disabled=False, - layout=widgets.Layout(display='none') + layout=widgets.Layout(display="none"), ) samples_pg_language_dropdown = widgets.Dropdown( options=SEED_LANGUAGE_OPTIONS[:3], - description='Language:', + description="Language:", disabled=False, - layout=widgets.Layout(display='none') + layout=widgets.Layout(display="none"), ) data_set_drop_down = widgets.Dropdown( - description='Data set:', + description="Data set:", disabled=False, - layout=widgets.Layout(display='none') + layout=widgets.Layout(display="none"), ) fullfile_option_dropdown = widgets.Dropdown( - description='Full File Query:', + description="Full File Query:", options=[True, False], value=args.full_file_query, disabled=False, - layout=widgets.Layout(display='none') + layout=widgets.Layout(display="none"), ) location_option_dropdown = widgets.Dropdown( - description='Location:', - options=['Local', 'S3'], - value='Local', + description="Location:", + options=["Local", "S3"], + value="Local", disabled=False, - layout=widgets.Layout(display='none') + layout=widgets.Layout(display="none"), ) seed_file_location_text = widgets.Text( - description='Source:', - placeholder='path/to/seedfiles/directory', - disabled=False + description="Source:", + placeholder="path/to/seedfiles/directory", + disabled=False, ) seed_file_location = FileChooser() - seed_file_location.layout.display = 'none' + seed_file_location.layout.display = "none" seed_file_location_text_hbox = 
widgets.HBox([seed_file_location_text]) submit_button = widgets.Button(description="Submit") - model_dropdown.layout.visibility = 'hidden' - custom_language_dropdown.layout.visibility = 'hidden' - samples_pg_language_dropdown.layout.visibility = 'hidden' - data_set_drop_down.layout.visibility = 'hidden' - fullfile_option_dropdown.layout.visibility = 'hidden' - location_option_dropdown.layout.visibility = 'hidden' - seed_file_location_text_hbox.layout.visibility = 'hidden' - seed_file_location.layout.visibility = 'hidden' - submit_button.layout.visibility = 'hidden' + model_dropdown.layout.visibility = "hidden" + custom_language_dropdown.layout.visibility = "hidden" + samples_pg_language_dropdown.layout.visibility = "hidden" + data_set_drop_down.layout.visibility = "hidden" + fullfile_option_dropdown.layout.visibility = "hidden" + location_option_dropdown.layout.visibility = "hidden" + seed_file_location_text_hbox.layout.visibility = "hidden" + seed_file_location.layout.visibility = "hidden" + submit_button.layout.visibility = "hidden" def hide_all_widgets(): - location_option_dropdown.layout.visibility = 'hidden' - location_option_dropdown.layout.display = 'none' - seed_file_location_text_hbox.layout.visibility = 'hidden' - seed_file_location_text_hbox.layout.display = 'none' - custom_language_dropdown.layout.visibility = 'hidden' - custom_language_dropdown.layout.display = 'none' - samples_pg_language_dropdown.layout.visibility = 'hidden' - samples_pg_language_dropdown.layout.display = 'none' - fullfile_option_dropdown.layout.visibility = 'hidden' - fullfile_option_dropdown.layout.display = 'none' - seed_file_location.layout.visibility = 'hidden' - seed_file_location.layout.display = 'none' - seed_file_location_text_hbox.layout.visibility = 'hidden' - seed_file_location_text_hbox.layout.display = 'none' - model_dropdown.layout.visibility = 'hidden' - model_dropdown.layout.display = 'none' - data_set_drop_down.layout.visibility = 'hidden' - data_set_drop_down.layout.display = 'none' - submit_button.layout.visibility = 'hidden' + location_option_dropdown.layout.visibility = "hidden" + location_option_dropdown.layout.display = "none" + seed_file_location_text_hbox.layout.visibility = "hidden" + seed_file_location_text_hbox.layout.display = "none" + custom_language_dropdown.layout.visibility = "hidden" + custom_language_dropdown.layout.display = "none" + samples_pg_language_dropdown.layout.visibility = "hidden" + samples_pg_language_dropdown.layout.display = "none" + fullfile_option_dropdown.layout.visibility = "hidden" + fullfile_option_dropdown.layout.display = "none" + seed_file_location.layout.visibility = "hidden" + seed_file_location.layout.display = "none" + seed_file_location_text_hbox.layout.visibility = "hidden" + seed_file_location_text_hbox.layout.display = "none" + model_dropdown.layout.visibility = "hidden" + model_dropdown.layout.display = "none" + data_set_drop_down.layout.visibility = "hidden" + data_set_drop_down.layout.display = "none" + submit_button.layout.visibility = "hidden" def on_source_value_change(change): hide_all_widgets() - selected_source = change['new'] - if selected_source == 'custom': - custom_language_dropdown.layout.visibility = 'visible' - custom_language_dropdown.layout.display = 'flex' - location_option_dropdown.layout.visibility = 'visible' - location_option_dropdown.layout.display = 'flex' + selected_source = change["new"] + if selected_source == "custom": + custom_language_dropdown.layout.visibility = "visible" + 
custom_language_dropdown.layout.display = "flex" + location_option_dropdown.layout.visibility = "visible" + location_option_dropdown.layout.display = "flex" if custom_language_dropdown.value: - if custom_language_dropdown.value != 'sparql': - fullfile_option_dropdown.layout.visibility = 'visible' - fullfile_option_dropdown.layout.display = 'flex' + if custom_language_dropdown.value != "sparql": + fullfile_option_dropdown.layout.visibility = "visible" + fullfile_option_dropdown.layout.display = "flex" # If textbox has a value, OR we are loading from S3, display textbox instead of the filepicker - if seed_file_location_text.value or location_option_dropdown.value == 'S3': - seed_file_location_text_hbox.layout.visibility = 'visible' - seed_file_location_text_hbox.layout.display = 'flex' - elif seed_file_location.value or location_option_dropdown.value == 'Local': - seed_file_location.layout.visibility = 'visible' - seed_file_location.layout.display = 'flex' - if custom_language_dropdown.value \ - and (seed_file_location_text.value or - (seed_file_location.value and location_option_dropdown.value == 'Local')): - submit_button.layout.visibility = 'visible' - elif selected_source == 'samples': - custom_language_dropdown.layout.visibility = 'hidden' - custom_language_dropdown.layout.display = 'none' - fullfile_option_dropdown.layout.visibility = 'hidden' - fullfile_option_dropdown.layout.display = 'none' - seed_file_location.layout.visibility = 'hidden' - seed_file_location.layout.display = 'none' - model_dropdown.layout.visibility = 'visible' - model_dropdown.layout.display = 'flex' + if ( + seed_file_location_text.value + or location_option_dropdown.value == "S3" + ): + seed_file_location_text_hbox.layout.visibility = "visible" + seed_file_location_text_hbox.layout.display = "flex" + elif ( + seed_file_location.value + or location_option_dropdown.value == "Local" + ): + seed_file_location.layout.visibility = "visible" + seed_file_location.layout.display = "flex" + if custom_language_dropdown.value and ( + seed_file_location_text.value + or ( + seed_file_location.value + and location_option_dropdown.value == "Local" + ) + ): + submit_button.layout.visibility = "visible" + elif selected_source == "samples": + custom_language_dropdown.layout.visibility = "hidden" + custom_language_dropdown.layout.display = "none" + fullfile_option_dropdown.layout.visibility = "hidden" + fullfile_option_dropdown.layout.display = "none" + seed_file_location.layout.visibility = "hidden" + seed_file_location.layout.display = "none" + model_dropdown.layout.visibility = "visible" + model_dropdown.layout.display = "flex" if model_dropdown.value: show_dataset = False - if model_dropdown.value == 'propertygraph': - samples_pg_language_dropdown.layout.visibility = 'visible' - samples_pg_language_dropdown.layout.display = 'flex' - if samples_pg_language_dropdown.value != '': + if model_dropdown.value == "propertygraph": + samples_pg_language_dropdown.layout.visibility = "visible" + samples_pg_language_dropdown.layout.display = "flex" + if samples_pg_language_dropdown.value != "": show_dataset = True else: - samples_pg_language_dropdown.layout.visibility = 'hidden' - samples_pg_language_dropdown.layout.display = 'none' + samples_pg_language_dropdown.layout.visibility = "hidden" + samples_pg_language_dropdown.layout.display = "none" show_dataset = True if show_dataset: - data_set_drop_down.layout.visibility = 'visible' - data_set_drop_down.layout.display = 'flex' - if data_set_drop_down.value and data_set_drop_down.value != 
SEED_NO_DATASETS_FOUND_MSG: - submit_button.layout.visibility = 'visible' + data_set_drop_down.layout.visibility = "visible" + data_set_drop_down.layout.display = "flex" + if ( + data_set_drop_down.value + and data_set_drop_down.value != SEED_NO_DATASETS_FOUND_MSG + ): + submit_button.layout.visibility = "visible" else: - custom_language_dropdown.layout.visibility = 'hidden' - custom_language_dropdown.layout.display = 'none' - samples_pg_language_dropdown.layout.visibility = 'hidden' - samples_pg_language_dropdown.layout.display = 'none' - fullfile_option_dropdown.layout.visibility = 'hidden' - fullfile_option_dropdown.layout.display = 'none' - seed_file_location.layout.visibility = 'hidden' - seed_file_location.layout.display = 'none' - seed_file_location_text.layout.visibility = 'hidden' - seed_file_location_text.layout.display = 'none' - model_dropdown.layout.visibility = 'hidden' - model_dropdown.layout.display = 'none' - data_set_drop_down.layout.visibility = 'hidden' - data_set_drop_down.layout.display = 'none' + custom_language_dropdown.layout.visibility = "hidden" + custom_language_dropdown.layout.display = "none" + samples_pg_language_dropdown.layout.visibility = "hidden" + samples_pg_language_dropdown.layout.display = "none" + fullfile_option_dropdown.layout.visibility = "hidden" + fullfile_option_dropdown.layout.display = "none" + seed_file_location.layout.visibility = "hidden" + seed_file_location.layout.display = "none" + seed_file_location_text.layout.visibility = "hidden" + seed_file_location_text.layout.display = "none" + model_dropdown.layout.visibility = "hidden" + model_dropdown.layout.display = "none" + data_set_drop_down.layout.visibility = "hidden" + data_set_drop_down.layout.display = "none" return def change_datasets_widget(samples_lang): data_sets = get_data_sets(samples_lang) if data_sets: data_sets.sort() - data_set_drop_down.options = [ds for ds in data_sets if - ds != '__pycache__'] # being extra sure that we aren't passing __pycache__. - data_set_drop_down.layout.visibility = 'visible' - data_set_drop_down.layout.display = 'flex' - submit_button.layout.visibility = 'visible' + data_set_drop_down.options = [ + ds for ds in data_sets if ds != "__pycache__" + ] # being extra sure that we aren't passing __pycache__. 
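# [Editor's aside — illustrative sketch, not part of the patch.] The list
# comprehension above excludes Python's __pycache__ directory from the sample
# dataset options; a standalone equivalent, assuming the datasets are the
# subdirectories of a samples folder (list_sample_datasets is an invented
# name; get_data_sets is the real helper used here):
import os

def list_sample_datasets(samples_dir: str) -> list:
    # Keep only genuine dataset directories, never __pycache__.
    return sorted(
        entry
        for entry in os.listdir(samples_dir)
        if entry != "__pycache__" and os.path.isdir(os.path.join(samples_dir, entry))
    )
# [End of aside.]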
+ data_set_drop_down.layout.visibility = "visible" + data_set_drop_down.layout.display = "flex" + submit_button.layout.visibility = "visible" else: if samples_lang: data_set_drop_down.options = [SEED_NO_DATASETS_FOUND_MSG] - data_set_drop_down.layout.visibility = 'visible' - data_set_drop_down.layout.display = 'flex' + data_set_drop_down.layout.visibility = "visible" + data_set_drop_down.layout.display = "flex" else: - data_set_drop_down.layout.visibility = 'hidden' - data_set_drop_down.layout.display = 'none' - submit_button.layout.visibility = 'hidden' + data_set_drop_down.layout.visibility = "hidden" + data_set_drop_down.layout.display = "none" + submit_button.layout.visibility = "hidden" return def on_model_value_change(change): - selected_model = change['new'] - samples_language = '' - if selected_model == 'propertygraph': - samples_pg_language_dropdown.layout.visibility = 'visible' - samples_pg_language_dropdown.layout.display = 'flex' - if samples_pg_language_dropdown.value != '': + selected_model = change["new"] + samples_language = "" + if selected_model == "propertygraph": + samples_pg_language_dropdown.layout.visibility = "visible" + samples_pg_language_dropdown.layout.display = "flex" + if samples_pg_language_dropdown.value != "": samples_language = samples_pg_language_dropdown.value else: - samples_pg_language_dropdown.layout.visibility = 'hidden' - samples_pg_language_dropdown.layout.display = 'none' - if selected_model == 'rdf': - samples_language = 'sparql' + samples_pg_language_dropdown.layout.visibility = "hidden" + samples_pg_language_dropdown.layout.display = "none" + if selected_model == "rdf": + samples_language = "sparql" change_datasets_widget(samples_language) return def on_dataset_value_change(change): - selected_dataset = change['new'] + selected_dataset = change["new"] if not selected_dataset: - submit_button.layout.visibility = 'hidden' + submit_button.layout.visibility = "hidden" return def on_samples_pg_language_value_change(change): - selected_pg_language = change['new'] + selected_pg_language = change["new"] change_datasets_widget(selected_pg_language) return def on_custom_language_value_change(change): # Preserve the value/state of the text/selector widget if it's already rendered # Otherwise, display the default selector widget (file browser) - selected_language = change['new'] - if selected_language != 'sparql': - fullfile_option_dropdown.layout.visibility = 'visible' - fullfile_option_dropdown.layout.display = 'flex' + selected_language = change["new"] + if selected_language != "sparql": + fullfile_option_dropdown.layout.visibility = "visible" + fullfile_option_dropdown.layout.display = "flex" else: - fullfile_option_dropdown.layout.visibility = 'hidden' - fullfile_option_dropdown.layout.display = 'none' - if not seed_file_location_text.value and seed_file_location_text_hbox.layout.visibility == 'hidden': - seed_file_location.layout.visibility = 'visible' - seed_file_location.layout.display = 'flex' - submit_button.layout.visibility = 'visible' + fullfile_option_dropdown.layout.visibility = "hidden" + fullfile_option_dropdown.layout.display = "none" + if ( + not seed_file_location_text.value + and seed_file_location_text_hbox.layout.visibility == "hidden" + ): + seed_file_location.layout.visibility = "visible" + seed_file_location.layout.display = "flex" + submit_button.layout.visibility = "visible" return def on_location_value_change(change): - selected_location = change['new'] - if selected_location == 'Local' and not seed_file_location_text.value: - 
seed_file_location_text_hbox.layout.visibility = 'hidden' - seed_file_location_text_hbox.layout.display = 'none' - seed_file_location.layout.visibility = 'visible' - seed_file_location.layout.display = 'flex' + selected_location = change["new"] + if selected_location == "Local" and not seed_file_location_text.value: + seed_file_location_text_hbox.layout.visibility = "hidden" + seed_file_location_text_hbox.layout.display = "none" + seed_file_location.layout.visibility = "visible" + seed_file_location.layout.display = "flex" else: - seed_file_location.layout.visibility = 'hidden' - seed_file_location.layout.display = 'none' - seed_file_location_text_hbox.layout.visibility = 'visible' - seed_file_location_text_hbox.layout.display = 'flex' + seed_file_location.layout.visibility = "hidden" + seed_file_location.layout.display = "none" + seed_file_location_text_hbox.layout.visibility = "visible" + seed_file_location_text_hbox.layout.display = "flex" return def on_seedfile_text_value_change(change): if seed_file_location_text.value: - submit_button.layout.visibility = 'visible' + submit_button.layout.visibility = "visible" else: - submit_button.layout.visibility = 'hidden' + submit_button.layout.visibility = "hidden" return def on_seedfile_select_value_change(change): if seed_file_location.value: - submit_button.layout.visibility = 'visible' + submit_button.layout.visibility = "visible" else: - submit_button.layout.visibility = 'hidden' + submit_button.layout.visibility = "hidden" return def disable_seed_widgets(): @@ -2307,43 +3111,49 @@ def process_gremlin_query_line(query_line, line_index, q): # return 1 = continue, set any_errors_flag = True, error_count += 1 # return 2 = progress.close() and return, set any_errors_flag = True, error_count += 1 if not query_line: - logger.debug(f"Skipped blank query at line {line_index + 1} in seed file {q['name']}") + logger.debug( + f"Skipped blank query at line {line_index + 1} in seed file {q['name']}" + ) return 0 try: self.client.gremlin_query(query_line) return 0 except GremlinServerError as gremlinEx: try: - error = json.loads(gremlinEx.args[0][5:]) # remove the leading error code. + error = json.loads( + gremlinEx.args[0][5:] + ) # remove the leading error code. 
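# [Editor's aside — illustrative sketch, not part of the patch.] The slice
# above assumes a GremlinServerError message carries a three-digit code and
# separator (e.g. "597: ") ahead of its JSON body, so dropping the first five
# characters leaves parseable JSON; the sample message below is invented:
#
#     import json
#     raw = '597: {"code": "ConstraintViolationException", "detailedMessage": "example"}'
#     error = json.loads(raw[5:])  # strip the leading "597: " before parsing
#     print(json.dumps(error, indent=2))
#
# [End of aside.]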
content = json.dumps(error, indent=2) except Exception: - content = { - 'error': gremlinEx - } - logger.debug(f"GremlinServerError at line {line_index + 1} in seed file {q['name']}") + content = {"error": gremlinEx} + logger.debug( + f"GremlinServerError at line {line_index + 1} in seed file {q['name']}" + ) logger.debug(content) if args.ignore_errors: return 1 else: with output: - generate_seed_error_msg(content, q['name'], line_index + 1) + generate_seed_error_msg(content, q["name"], line_index + 1) return 2 except Exception as e: - content = { - 'error': e - } - logger.debug(f"Exception at line {line_index + 1} in seed file {q['name']}") + content = {"error": e} + logger.debug( + f"Exception at line {line_index + 1} in seed file {q['name']}" + ) logger.debug(content) if args.ignore_errors: return 1 else: with output: - generate_seed_error_msg(content, q['name'], line_index + 1) + generate_seed_error_msg(content, q["name"], line_index + 1) return 2 def process_cypher_query_line(query_line, line_index, q): if not query_line: - logger.debug(f"Skipped blank query at line {line_index + 1} in seed file {q['name']}") + logger.debug( + f"Skipped blank query at line {line_index + 1} in seed file {q['name']}" + ) return 0 try: cypher_res = self.client.opencypher_http(query_line) @@ -2351,53 +3161,59 @@ def process_cypher_query_line(query_line, line_index, q): return 0 except HTTPError as httpEx: try: - error = json.loads(httpEx.response.content.decode('utf-8')) + error = json.loads(httpEx.response.content.decode("utf-8")) content = json.dumps(error, indent=2) except Exception: - content = { - 'error': httpEx - } + content = {"error": httpEx} logger.debug(content) if args.ignore_errors: return 1 else: with output: - generate_seed_error_msg(content, q['name']) + generate_seed_error_msg(content, q["name"]) return 2 except Exception as ex: - content = { - 'error': str(ex) - } + content = {"error": str(ex)} logger.error(content) if args.ignore_errors: return 1 else: with output: - generate_seed_error_msg(content, q['name']) + generate_seed_error_msg(content, q["name"]) return 2 def on_button_clicked(b=None): seed_file_location_text_hbox.children = (seed_file_location_text,) filename = None - if source_dropdown.value == 'samples': + if source_dropdown.value == "samples": data_set = data_set_drop_down.value.lower() fullfile_query = False else: if seed_file_location_text.value: stall_with_warning = False - if location_option_dropdown.value == 'S3' and not (seed_file_location_text.value.startswith('s3://') - and len(seed_file_location_text.value) > 7): + if location_option_dropdown.value == "S3" and not ( + seed_file_location_text.value.startswith("s3://") + and len(seed_file_location_text.value) > 7 + ): seed_file_location_text_validation_label = widgets.HTML( - '
<p style="color:red;">S3 source URI must start with s3://</p>')
+                            '<p style="color:red;">S3 source URI must start with s3://</p>'
+                        )
                         stall_with_warning = True
-                    elif location_option_dropdown.value == 'Local' \
-                            and not seed_file_location_text.value.startswith('/'):
+                    elif (
+                        location_option_dropdown.value == "Local"
+                        and not seed_file_location_text.value.startswith("/")
+                    ):
                         seed_file_location_text_validation_label = widgets.HTML(
-                            '<p style="color:red;">Local source URI must be a valid file path</p>')
+                            '<p style="color:red;">Local source URI must be a valid file path</p>
' + ) stall_with_warning = True if stall_with_warning: - seed_file_location_text_validation_label.style = DescriptionStyle(color='red') - seed_file_location_text_hbox.children += (seed_file_location_text_validation_label,) + seed_file_location_text_validation_label.style = ( + DescriptionStyle(color="red") + ) + seed_file_location_text_hbox.children += ( + seed_file_location_text_validation_label, + ) return filename = seed_file_location_text.value elif seed_file_location.value: @@ -2409,21 +3225,25 @@ def on_button_clicked(b=None): disable_seed_widgets() if custom_language_dropdown.value and filename: model = normalize_model_name(custom_language_dropdown.value) - seeding_language = normalize_language_name(custom_language_dropdown.value) + seeding_language = normalize_language_name( + custom_language_dropdown.value + ) else: model = normalize_model_name(model_dropdown.value) - seeding_language = 'sparql' if model == 'rdf' else samples_pg_language_dropdown.value + seeding_language = ( + "sparql" if model == "rdf" else samples_pg_language_dropdown.value + ) with output: - print(f'Loading data set {data_set} for {seeding_language}') + print(f"Loading data set {data_set} for {seeding_language}") queries = get_queries(seeding_language, data_set, source_dropdown.value) if queries: if len(queries) < 1: with output: - print('Did not find any queries for the given dataset') + print("Did not find any queries for the given dataset") return else: with output: - print('Query retrieval from files terminated with errors.') + print("Query retrieval from files terminated with errors.") return load_index = 1 # start at 1 to have a non-empty progress bar @@ -2431,9 +3251,9 @@ def on_button_clicked(b=None): value=load_index, min=0, max=len(queries) + 1, # len + 1 so we can start at index 1 - orientation='horizontal', - bar_style='info', - description='Loading:' + orientation="horizontal", + bar_style="info", + description="Loading:", ) with progress_output: @@ -2444,50 +3264,48 @@ def on_button_clicked(b=None): for q in queries: with output: print(f'{progress.value}/{len(queries)}:\t{q["name"]}') - if model == 'rdf': + if model == "rdf": try: - self.client.sparql(q['content'], path=args.path) + self.client.sparql(q["content"], path=args.path) except HTTPError as httpEx: # attempt to turn response into json try: - error = json.loads(httpEx.response.content.decode('utf-8')) + error = json.loads(httpEx.response.content.decode("utf-8")) content = json.dumps(error, indent=2) except Exception: any_errors_flag = True error_count += 1 - content = { - 'error': httpEx - } + content = {"error": httpEx} logger.debug(content) if args.ignore_errors: progress.value += 1 continue else: with output: - generate_seed_error_msg(content, q['name']) + generate_seed_error_msg(content, q["name"]) progress.close() return except Exception as ex: any_errors_flag = True error_count += 1 - content = { - 'error': str(ex) - } + content = {"error": str(ex)} logger.error(content) if args.ignore_errors: progress.value += 1 continue else: with output: - generate_seed_error_msg(content, q['name']) + generate_seed_error_msg(content, q["name"]) progress.close() return else: # gremlin and cypher if fullfile_query: # treat entire file content as one query - if seeding_language == 'opencypher': - query_status = process_cypher_query_line(q['content'], 0, q) + if seeding_language == "opencypher": + query_status = process_cypher_query_line(q["content"], 0, q) else: - query_status = process_gremlin_query_line(q['content'], 0, q) + query_status = 
process_gremlin_query_line( + q["content"], 0, q + ) if query_status == 2: progress.close() return @@ -2498,11 +3316,17 @@ def on_button_clicked(b=None): progress.value += 1 continue else: # treat each line as its own query - for line_index, query_line in enumerate(q['content'].splitlines()): - if seeding_language == 'opencypher': - query_status = process_cypher_query_line(query_line, line_index, q) + for line_index, query_line in enumerate( + q["content"].splitlines() + ): + if seeding_language == "opencypher": + query_status = process_cypher_query_line( + query_line, line_index, q + ) else: - query_status = process_gremlin_query_line(query_line, line_index, q) + query_status = process_gremlin_query_line( + query_line, line_index, q + ) if query_status == 2: progress.close() return @@ -2516,75 +3340,93 @@ def on_button_clicked(b=None): time.sleep(2) progress.close() with output: - print('Done.') + print("Done.") if any_errors_flag: - print(f'\n{error_count} individual queries were skipped due to errors. For more ' - f'information, please rerun the query with debug logs enabled (%enable_debug).') + print( + f"\n{error_count} individual queries were skipped due to errors. For more " + f"information, please rerun the query with debug logs enabled (%enable_debug)." + ) return submit_button.on_click(on_button_clicked) - source_dropdown.observe(on_source_value_change, names='value') - model_dropdown.observe(on_model_value_change, names='value') - data_set_drop_down.observe(on_dataset_value_change, names='value') - custom_language_dropdown.observe(on_custom_language_value_change, names='value') - samples_pg_language_dropdown.observe(on_samples_pg_language_value_change, names='value') - location_option_dropdown.observe(on_location_value_change, names='value') - seed_file_location_text.observe(on_seedfile_text_value_change, names='value') - seed_file_location.observe(on_seedfile_select_value_change, names='value') - - display(source_dropdown, model_dropdown, custom_language_dropdown, samples_pg_language_dropdown, - data_set_drop_down, fullfile_option_dropdown, location_option_dropdown, seed_file_location, - seed_file_location_text_hbox, submit_button, progress_output, output) - - if (args.model != '' or args.language != '') and args.source == '': - source_dropdown.value = 'samples' + source_dropdown.observe(on_source_value_change, names="value") + model_dropdown.observe(on_model_value_change, names="value") + data_set_drop_down.observe(on_dataset_value_change, names="value") + custom_language_dropdown.observe(on_custom_language_value_change, names="value") + samples_pg_language_dropdown.observe( + on_samples_pg_language_value_change, names="value" + ) + location_option_dropdown.observe(on_location_value_change, names="value") + seed_file_location_text.observe(on_seedfile_text_value_change, names="value") + seed_file_location.observe(on_seedfile_select_value_change, names="value") + + display( + source_dropdown, + model_dropdown, + custom_language_dropdown, + samples_pg_language_dropdown, + data_set_drop_down, + fullfile_option_dropdown, + location_option_dropdown, + seed_file_location, + seed_file_location_text_hbox, + submit_button, + progress_output, + output, + ) + + if (args.model != "" or args.language != "") and args.source == "": + source_dropdown.value = "samples" normed_model = normalize_model_name(args.model) normed_language = normalize_language_name(args.language) selected_model = None selected_language = None - if normed_model != '' and normed_model in SEED_MODEL_OPTIONS: - if 
normed_model == 'propertygraph': - selected_model = 'propertygraph' - if normed_language in ['gremlin', 'opencypher']: + if normed_model != "" and normed_model in SEED_MODEL_OPTIONS: + if normed_model == "propertygraph": + selected_model = "propertygraph" + if normed_language in ["gremlin", "opencypher"]: selected_language = normed_language - elif normed_language == '': - selected_language = 'gremlin' + elif normed_language == "": + selected_language = "gremlin" else: - selected_model = 'rdf' - selected_language = 'sparql' - elif normed_language != '' and normed_language in SEED_LANGUAGE_OPTIONS: - if normed_language == 'sparql': - selected_model = 'rdf' - selected_language = 'sparql' + selected_model = "rdf" + selected_language = "sparql" + elif normed_language != "" and normed_language in SEED_LANGUAGE_OPTIONS: + if normed_language == "sparql": + selected_model = "rdf" + selected_language = "sparql" else: - selected_model = 'propertygraph' + selected_model = "propertygraph" selected_language = normed_language if selected_model: model_dropdown.value = selected_model if selected_language: - if selected_language != 'sparql': + if selected_language != "sparql": samples_pg_language_dropdown.value = selected_language - if args.dataset != '' and args.dataset in data_set_drop_down.options: + if ( + args.dataset != "" + and args.dataset in data_set_drop_down.options + ): data_set_drop_down.value = args.dataset.lower() if args.run: on_button_clicked() - elif args.source != '' or args.language != '': - source_dropdown.value = 'custom' + elif args.source != "" or args.language != "": + source_dropdown.value = "custom" valid_language_value = False language = normalize_language_name(args.language) - if language != '' and language in SEED_LANGUAGE_OPTIONS: + if language != "" and language in SEED_LANGUAGE_OPTIONS: custom_language_dropdown.value = language valid_language_value = True - if args.source != '': + if args.source != "": seed_file_location_text.value = args.source - seed_file_location_text_hbox.layout.visibility = 'visible' - seed_file_location_text_hbox.layout.display = 'flex' - if seed_file_location_text.value.startswith('s3://'): - location_option_dropdown.value = 'S3' - location_option_dropdown.layout.visibility = 'visible' - location_option_dropdown.layout.display = 'flex' - seed_file_location.layout.visibility = 'hidden' - seed_file_location.layout.display = 'none' + seed_file_location_text_hbox.layout.visibility = "visible" + seed_file_location_text_hbox.layout.display = "flex" + if seed_file_location_text.value.startswith("s3://"): + location_option_dropdown.value = "S3" + location_option_dropdown.layout.visibility = "visible" + location_option_dropdown.layout.display = "flex" + seed_file_location.layout.visibility = "hidden" + seed_file_location.layout.display = "none" if seed_file_location_text.value and valid_language_value and args.run: on_button_clicked() @@ -2601,7 +3443,7 @@ def disable_debug(self, line): @line_magic @needs_local_scope def toggle_traceback(self, line, local_ns: dict = None): - show_traceback_ns_var = 'graph_notebook_show_traceback' + show_traceback_ns_var = "graph_notebook_show_traceback" try: show_traceback = local_ns[show_traceback_ns_var] if not isinstance(show_traceback, bool): @@ -2611,7 +3453,9 @@ def toggle_traceback(self, line, local_ns: dict = None): except KeyError: show_traceback = True - print(f"Display of tracebacks from magics is toggled {'ON' if show_traceback else 'OFF'}.") + print( + f"Display of tracebacks from magics is toggled {'ON' if 
show_traceback else 'OFF'}." + ) store_to_ns(show_traceback_ns_var, show_traceback, local_ns) @line_magic @@ -2622,25 +3466,37 @@ def graph_notebook_version(self, line): @line_cell_magic @display_exceptions @needs_local_scope - def graph_notebook_vis_options(self, line='', cell='', local_ns: dict = None): + def graph_notebook_vis_options(self, line="", cell="", local_ns: dict = None): parser = argparse.ArgumentParser() - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='', help='store visualization settings to this variable') - parser.add_argument('--load-from', type=str, default='', help='load visualization settings from this variable') + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." + ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store visualization settings to this variable", + ) + parser.add_argument( + "--load-from", + type=str, + default="", + help="load visualization settings from this variable", + ) line_args = line.split() if line_args: - if line_args[0] == 'reset': - line = 'reset' + if line_args[0] == "reset": + line = "reset" if len(line_args) > 1: line_args = line_args[1:] else: line_args = [] args = parser.parse_args(line_args) - if line == 'reset': + if line == "reset": self.graph_notebook_vis_options = OPTIONS_DEFAULT_DIRECTED - if cell == '' and not args.load_from: + if cell == "" and not args.load_from: if not args.silent: print(json.dumps(self.graph_notebook_vis_options, indent=2)) else: @@ -2652,30 +3508,42 @@ def graph_notebook_vis_options(self, line='', cell='', local_ns: dict = None): options_raw = json.dumps(options_raw) options_dict = json.loads(options_raw) except KeyError: - print(f"Unable to load visualization settings, variable [{args.load_from}] does not exist in " - f"the local namespace.") + print( + f"Unable to load visualization settings, variable [{args.load_from}] does not exist in " + f"the local namespace." 
+ ) return else: options_dict = json.loads(cell) except (JSONDecodeError, TypeError) as e: - print(f"Unable to load visualization settings, variable [{args.load_from}] is not in valid JSON " - f"format:\n") + print( + f"Unable to load visualization settings, variable [{args.load_from}] is not in valid JSON " + f"format:\n" + ) print(e) return - self.graph_notebook_vis_options = vis_options_merge(self.graph_notebook_vis_options, options_dict) + self.graph_notebook_vis_options = vis_options_merge( + self.graph_notebook_vis_options, options_dict + ) print("Visualization settings successfully changed to:\n") print(json.dumps(self.graph_notebook_vis_options, indent=2)) - store_to_ns(args.store_to, json.dumps(self.graph_notebook_vis_options, indent=2), local_ns) + store_to_ns( + args.store_to, + json.dumps(self.graph_notebook_vis_options, indent=2), + local_ns, + ) @magic_variables @line_cell_magic @display_exceptions @needs_local_scope - def neptune_ml(self, line, cell='', local_ns: dict = None): + def neptune_ml(self, line, cell="", local_ns: dict = None): parser = generate_neptune_ml_parser() args = parser.parse_args(line.split()) - logger.info(f'received call to neptune_ml with details: {args.__dict__}, cell={cell}, local_ns={local_ns}') + logger.info( + f"received call to neptune_ml with details: {args.__dict__}, cell={cell}, local_ns={local_ns}" + ) main_output = widgets.Output() display(main_output) res = neptune_ml_magic_handler(args, self.client, main_output, cell) @@ -2689,47 +3557,144 @@ def handle_opencypher_query(self, line, cell, local_ns): This method in its own handler so that the magics %%opencypher and %%oc can both call it """ parser = argparse.ArgumentParser() - parser.add_argument('--explain-type', type=str.lower, default='dynamic', - help=f'Explain mode to use when using the explain query mode. ' - f'Accepted values: ${OPENCYPHER_EXPLAIN_MODES}') - parser.add_argument('-qp', '--query-parameters', type=str, default='', - help='Parameter definitions to apply to the query. This option can accept a local variable ' - 'name, or a string representation of the map.') - parser.add_argument('-g', '--group-by', type=str, default='~labels', - help='Property used to group nodes (e.g. code, ~id) default is ~labels') - parser.add_argument('-gd', '--group-by-depth', action='store_true', default=False, - help="Group nodes based on path hierarchy") - parser.add_argument('-gr', '--group-by-raw', action='store_true', default=False, - help="Group nodes by the raw result") - parser.add_argument('mode', nargs='?', default='query', help='query mode [query|bolt|explain]', - choices=['query', 'bolt', 'explain']) - parser.add_argument('-d', '--display-property', type=str, default='~labels', - help='Property to display the value of on each node, default is ~labels') - parser.add_argument('-de', '--edge-display-property', type=str, default='~labels', - help='Property to display the value of on each edge, default is ~type') - parser.add_argument('-t', '--tooltip-property', type=str, default='', - help='Property to display the value of on each node tooltip. If not specified, tooltip ' - 'will default to the node label value.') - parser.add_argument('-te', '--edge-tooltip-property', type=str, default='', - help='Property to display the value of on each edge tooltip. If not specified, tooltip ' - 'will default to the edge label value.') - parser.add_argument('-l', '--label-max-length', type=int, default=10, - help='Specifies max length of vertex label, in characters. 
Default is 10')
- parser.add_argument('-rel', '--rel-label-max-length', type=int, default=10,
- help='Specifies max length of edge labels, in characters. Default is 10')
- parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
- parser.add_argument('--ignore-groups', action='store_true', default=False, help="Ignore all grouping options")
- parser.add_argument('-sp', '--stop-physics', action='store_true', default=False,
- help="Disable visualization physics after the initial simulation stabilizes.")
- parser.add_argument('-sd', '--simulation-duration', type=int, default=1500,
- help='Specifies maximum duration of visualization physics simulation. Default is 1500ms')
- parser.add_argument('--silent', action='store_true', default=False, help="Display no query output.")
- parser.add_argument('-r', '--results-per-page', type=int, default=10,
- help='Specifies how many query results to display per page in the output. Default is 10')
- parser.add_argument('--no-scroll', action='store_true', default=False,
- help="Display the entire output without a scroll bar.")
- parser.add_argument('--hide-index', action='store_true', default=False,
- help="Hide the index column numbers when displaying the results.")
+ parser.add_argument(
+ "--explain-type",
+ type=str.lower,
+ default="dynamic",
+ help=f"Explain mode to use when running an explain query. "
+ f"Accepted values: {OPENCYPHER_EXPLAIN_MODES}",
+ )
+ parser.add_argument(
+ "-qp",
+ "--query-parameters",
+ type=str,
+ default="",
+ help="Parameter definitions to apply to the query. This option can accept a local variable "
+ "name, or a string representation of the map.",
+ )
+ parser.add_argument(
+ "-g",
+ "--group-by",
+ type=str,
+ default="~labels",
+ help="Property used to group nodes (e.g. code, ~id) default is ~labels",
+ )
+ parser.add_argument(
+ "-gd",
+ "--group-by-depth",
+ action="store_true",
+ default=False,
+ help="Group nodes based on path hierarchy",
+ )
+ parser.add_argument(
+ "-gr",
+ "--group-by-raw",
+ action="store_true",
+ default=False,
+ help="Group nodes by the raw result",
+ )
+ parser.add_argument(
+ "mode",
+ nargs="?",
+ default="query",
+ help="query mode [query|bolt|explain]",
+ choices=["query", "bolt", "explain"],
+ )
+ parser.add_argument(
+ "-d",
+ "--display-property",
+ type=str,
+ default="~labels",
+ help="Property to display the value of on each node, default is ~labels",
+ )
+ parser.add_argument(
+ "-de",
+ "--edge-display-property",
+ type=str,
+ default="~labels",
+ help="Property to display the value of on each edge, default is ~labels",
+ )
+ parser.add_argument(
+ "-t",
+ "--tooltip-property",
+ type=str,
+ default="",
+ help="Property to display the value of on each node tooltip. If not specified, tooltip "
+ "will default to the node label value.",
+ )
+ parser.add_argument(
+ "-te",
+ "--edge-tooltip-property",
+ type=str,
+ default="",
+ help="Property to display the value of on each edge tooltip. If not specified, tooltip "
+ "will default to the edge label value.",
+ )
+ parser.add_argument(
+ "-l",
+ "--label-max-length",
+ type=int,
+ default=10,
+ help="Specifies max length of vertex label, in characters. Default is 10",
+ )
+ parser.add_argument(
+ "-rel",
+ "--rel-label-max-length",
+ type=int,
+ default=10,
+ help="Specifies max length of edge labels, in characters.
Default is 10", + ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) + parser.add_argument( + "--ignore-groups", + action="store_true", + default=False, + help="Ignore all grouping options", + ) + parser.add_argument( + "-sp", + "--stop-physics", + action="store_true", + default=False, + help="Disable visualization physics after the initial simulation stabilizes.", + ) + parser.add_argument( + "-sd", + "--simulation-duration", + type=int, + default=1500, + help="Specifies maximum duration of visualization physics simulation. Default is 1500ms", + ) + parser.add_argument( + "--silent", + action="store_true", + default=False, + help="Display no query output.", + ) + parser.add_argument( + "-r", + "--results-per-page", + type=int, + default=10, + help="Specifies how many query results to display per page in the output. Default is 10", + ) + parser.add_argument( + "--no-scroll", + action="store_true", + default=False, + help="Display the entire output without a scroll bar.", + ) + parser.add_argument( + "--hide-index", + action="store_true", + default=False, + help="Hide the index column numbers when displaying the results.", + ) args = parser.parse_args(line.split()) logger.debug(args) res = None @@ -2768,24 +3733,35 @@ def handle_opencypher_query(self, line, cell, local_ns): first_tab_output = widgets.Output(layout=oc_layout) children.append(first_tab_output) - if args.mode == 'explain': - query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms - res = self.client.opencypher_http(cell, explain=args.explain_type, query_params=query_params) + if args.mode == "explain": + query_start = ( + time.time() * 1000 + ) # time.time() returns time in seconds w/high precision; x1000 to get in ms + res = self.client.opencypher_http( + cell, explain=args.explain_type, query_params=query_params + ) query_time = time.time() * 1000 - query_start explain = res.content.decode("utf-8") res.raise_for_status() ##store_to_ns(args.store_to, explain, local_ns) if not args.silent: - oc_metadata = build_opencypher_metadata_from_query(query_type='explain', results=None, - results_type='explain', res=res, - query_time=query_time) - titles.append('Explain') - explain_bytes = explain.encode('utf-8') - base64_str = base64.b64encode(explain_bytes).decode('utf-8') - first_tab_html = opencypher_explain_template.render(table=explain, - link=f"data:text/html;base64,{base64_str}") - elif args.mode == 'query': - query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms + oc_metadata = build_opencypher_metadata_from_query( + query_type="explain", + results=None, + results_type="explain", + res=res, + query_time=query_time, + ) + titles.append("Explain") + explain_bytes = explain.encode("utf-8") + base64_str = base64.b64encode(explain_bytes).decode("utf-8") + first_tab_html = opencypher_explain_template.render( + table=explain, link=f"data:text/html;base64,{base64_str}" + ) + elif args.mode == "query": + query_start = ( + time.time() * 1000 + ) # time.time() returns time in seconds w/high precision; x1000 to get in ms oc_http = self.client.opencypher_http(cell, query_params=query_params) query_time = time.time() * 1000 - query_start oc_http.raise_for_status() @@ -2804,45 +3780,61 @@ def handle_opencypher_query(self, line, cell, local_ns): res_format = "jolt" if not args.silent: - oc_metadata = build_opencypher_metadata_from_query(query_type='query', results=res, - 
results_type=res_format, query_time=query_time) + oc_metadata = build_opencypher_metadata_from_query( + query_type="query", + results=res, + results_type=res_format, + query_time=query_time, + ) first_tab_html = "" rows_and_columns = opencypher_get_rows_and_columns(res, res_format) if rows_and_columns: - titles.append('Console') - results_df = pd.DataFrame(rows_and_columns['rows']) + titles.append("Console") + results_df = pd.DataFrame(rows_and_columns["rows"]) results_df = results_df.astype(str) results_df = results_df.applymap(lambda x: replace_html_chars(x)) results_df.insert(0, "#", range(1, len(results_df) + 1)) - for col_index, col_name in enumerate(rows_and_columns['columns']): - results_df.rename({results_df.columns[col_index + 1]: col_name}, - axis='columns', - inplace=True) + for col_index, col_name in enumerate(rows_and_columns["columns"]): + results_df.rename( + {results_df.columns[col_index + 1]: col_name}, + axis="columns", + inplace=True, + ) try: - gn = OCNetwork(group_by_property=args.group_by, display_property=args.display_property, - group_by_raw=args.group_by_raw, - group_by_depth=args.group_by_depth, - edge_display_property=args.edge_display_property, - tooltip_property=args.tooltip_property, - edge_tooltip_property=args.edge_tooltip_property, - label_max_length=args.label_max_length, - edge_label_max_length=args.rel_label_max_length, - ignore_groups=args.ignore_groups) + gn = OCNetwork( + group_by_property=args.group_by, + display_property=args.display_property, + group_by_raw=args.group_by_raw, + group_by_depth=args.group_by_depth, + edge_display_property=args.edge_display_property, + tooltip_property=args.tooltip_property, + edge_tooltip_property=args.edge_tooltip_property, + label_max_length=args.label_max_length, + edge_label_max_length=args.rel_label_max_length, + ignore_groups=args.ignore_groups, + ) gn.add_results(res) - logger.debug(f'number of nodes is {len(gn.graph.nodes)}') + logger.debug(f"number of nodes is {len(gn.graph.nodes)}") if len(gn.graph.nodes) > 0: - self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ - = args.stop_physics - self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration - force_graph_output = Force(network=gn, options=self.graph_notebook_vis_options) - titles.append('Graph') + self.graph_notebook_vis_options["physics"][ + "disablePhysicsAfterInitialSimulation" + ] = args.stop_physics + self.graph_notebook_vis_options["physics"][ + "simulationDuration" + ] = args.simulation_duration + force_graph_output = Force( + network=gn, options=self.graph_notebook_vis_options + ) + titles.append("Graph") children.append(force_graph_output) except (TypeError, ValueError) as network_creation_error: - logger.debug(f'Unable to create network from result. Skipping from result set: {res}') - logger.debug(f'Error: {network_creation_error}') + logger.debug( + f"Unable to create network from result. 
Skipping from result set: {res}" + ) + logger.debug(f"Error: {network_creation_error}") - elif args.mode == 'bolt': - res_format = 'bolt' + elif args.mode == "bolt": + res_format = "bolt" query_start = time.time() * 1000 if query_params: res = self.client.opencyper_bolt(cell, **query_params) @@ -2850,38 +3842,46 @@ def handle_opencypher_query(self, line, cell, local_ns): res = self.client.opencyper_bolt(cell) query_time = time.time() * 1000 - query_start if not args.silent: - oc_metadata = build_opencypher_metadata_from_query(query_type='bolt', results=res, - results_type=res_format, query_time=query_time) + oc_metadata = build_opencypher_metadata_from_query( + query_type="bolt", + results=res, + results_type=res_format, + query_time=query_time, + ) first_tab_html = "" rows_and_columns = opencypher_get_rows_and_columns(res, res_format) if rows_and_columns: - titles.append('Console') - results_df = pd.DataFrame(rows_and_columns['rows']) + titles.append("Console") + results_df = pd.DataFrame(rows_and_columns["rows"]) results_df = results_df.astype(str) results_df = results_df.applymap(lambda x: replace_html_chars(x)) results_df.insert(0, "#", range(1, len(results_df) + 1)) - for col_index, col_name in enumerate(rows_and_columns['columns']): - results_df.rename({results_df.columns[col_index + 1]: col_name}, - axis='columns', - inplace=True) + for col_index, col_name in enumerate(rows_and_columns["columns"]): + results_df.rename( + {results_df.columns[col_index + 1]: col_name}, + axis="columns", + inplace=True, + ) # Need to eventually add code to parse and display a network for the bolt format here if not args.silent: - if args.mode != 'explain': + if args.mode != "explain": # Display JSON tab json_output = widgets.Output(layout=oc_layout) with json_output: print(json.dumps(res, indent=2)) children.append(json_output) - titles.append('JSON') + titles.append("JSON") # Display Query Metadata Tab metadata_output = widgets.Output(layout=oc_layout) - titles.append('Query Metadata') + titles.append("Query Metadata") children.append(metadata_output) if first_tab_html == "" and results_df is None: - tab.children = children[1:] # the first tab is empty, remove it and proceed + tab.children = children[ + 1: + ] # the first tab is empty, remove it and proceed else: tab.children = children @@ -2894,27 +3894,34 @@ def handle_opencypher_query(self, line, cell, local_ns): if results_df is not None: with first_tab_output: - visible_results, final_pagination_options, final_pagination_menu = generate_pagination_vars( - args.results_per_page) + ( + visible_results, + final_pagination_options, + final_pagination_menu, + ) = generate_pagination_vars(args.results_per_page) oc_columndefs = [ {"width": "5%", "targets": 0}, {"visible": True, "targets": 0}, {"searchable": False, "targets": 0}, {"className": "nowrap dt-left", "targets": "_all"}, {"createdCell": JavascriptFunction(index_col_js), "targets": 0}, - {"createdCell": JavascriptFunction(cell_style_js), "targets": "_all", } + { + "createdCell": JavascriptFunction(cell_style_js), + "targets": "_all", + }, ] if args.hide_index: oc_columndefs[1]["visible"] = False - show(results_df, - scrollX=True, - scrollY=oc_scrollY, - columnDefs=oc_columndefs, - paging=oc_paging, - scrollCollapse=oc_scrollCollapse, - lengthMenu=[final_pagination_options, final_pagination_menu], - pageLength=visible_results - ) + show( + results_df, + scrollX=True, + scrollY=oc_scrollY, + columnDefs=oc_columndefs, + paging=oc_paging, + scrollCollapse=oc_scrollCollapse, + 
lengthMenu=[final_pagination_options, final_pagination_menu], + pageLength=visible_results, + ) elif first_tab_html != "": with first_tab_output: display(HTML(first_tab_html)) @@ -2927,21 +3934,47 @@ def handle_opencypher_status(self, line, local_ns): %opencypher_status or from %oc_status """ parser = argparse.ArgumentParser() - parser.add_argument('-q', '--queryId', default='', - help='The ID of a running OpenCypher query. ' - 'Only displays the status of the specified query.') - parser.add_argument('-c', '--cancelQuery', action='store_true', default=False, - help='Tells the status command to cancel a query. This parameter does not take a value.') - parser.add_argument('-w', '--includeWaiting', action='store_true', default=False, - help='When set to true and other parameters are not present, causes status information ' - 'for waiting queries to be returned as well as for running queries. ' - 'This parameter does not take a value.') - parser.add_argument('-s', '--silent-cancel', action='store_true', default=False, - help='If silent_cancel=true then the running query is cancelled and the HTTP response ' - 'code is 200. If silent_cancel is not present or silent_cancel=false, ' - 'the query is cancelled with an HTTP 500 status code.') - parser.add_argument('--silent', action='store_true', default=False, help="Display no output.") - parser.add_argument('--store-to', type=str, default='', help='store query result to this variable') + parser.add_argument( + "-q", + "--queryId", + default="", + help="The ID of a running OpenCypher query. " + "Only displays the status of the specified query.", + ) + parser.add_argument( + "-c", + "--cancelQuery", + action="store_true", + default=False, + help="Tells the status command to cancel a query. This parameter does not take a value.", + ) + parser.add_argument( + "-w", + "--includeWaiting", + action="store_true", + default=False, + help="When set to true and other parameters are not present, causes status information " + "for waiting queries to be returned as well as for running queries. " + "This parameter does not take a value.", + ) + parser.add_argument( + "-s", + "--silent-cancel", + action="store_true", + default=False, + help="If silent_cancel=true then the running query is cancelled and the HTTP response " + "code is 200. If silent_cancel is not present or silent_cancel=false, " + "the query is cancelled with an HTTP 500 status code.", + ) + parser.add_argument( + "--silent", action="store_true", default=False, help="Display no output." + ) + parser.add_argument( + "--store-to", + type=str, + default="", + help="store query result to this variable", + ) args = parser.parse_args(line.split()) if not args.cancelQuery: @@ -2951,7 +3984,7 @@ def handle_opencypher_status(self, line, local_ns): res = self.client.opencypher_status(query_id=args.queryId) res.raise_for_status() else: - if args.queryId == '': + if args.queryId == "": if not args.silent: print(OPENCYPHER_CANCEL_HINT_MSG) return diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py index b4e868c6..1bfe6678 100644 --- a/src/graph_notebook/neptune/client.py +++ b/src/graph_notebook/neptune/client.py @@ -29,93 +29,130 @@ # client >= 3.5.0 as the HashableDict is now part of that client driver. 
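
The hunk below is mostly quote normalization, but it also threads the new Memgraph options through `client.py`. As an orientation aid only, here is a minimal sketch (not part of the patch) of how the patched `Client` is expected to reach a local Memgraph over Bolt, assuming the Docker setup from `additional-databases/memgraph/README.md` with auth left disabled:

```
# Illustrative sketch, not part of this patch: drive the patched Client
# against a local Memgraph (Bolt on 7687, ssl off, auth disabled).
from graph_notebook.neptune.client import Client

client = Client(host="localhost", port=7687, ssl=False)
# memgraph_auth defaults to False, so the driver falls back to the
# Neo4j-style credentials, which Memgraph ignores when no users are configured.
# (opencyper_bolt is the existing method name in client.py.)
print(client.opencyper_bolt("MATCH (n) RETURN count(n)"))
```
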
# import graph_notebook.neptune.gremlin.graphsonV3d0_MapType_objectify_patch # noqa F401 -DEFAULT_GREMLIN_SERIALIZER = 'graphsonv3' -DEFAULT_GREMLIN_TRAVERSAL_SOURCE = 'g' -DEFAULT_SPARQL_CONTENT_TYPE = 'application/x-www-form-urlencoded' +DEFAULT_GREMLIN_SERIALIZER = "graphsonv3" +DEFAULT_GREMLIN_TRAVERSAL_SOURCE = "g" +DEFAULT_SPARQL_CONTENT_TYPE = "application/x-www-form-urlencoded" DEFAULT_PORT = 8182 -DEFAULT_REGION = 'us-east-1' -DEFAULT_NEO4J_USERNAME = 'neo4j' -DEFAULT_NEO4J_PASSWORD = 'password' +DEFAULT_REGION = "us-east-1" +DEFAULT_NEO4J_USERNAME = "neo4j" +DEFAULT_NEO4J_PASSWORD = "password" DEFAULT_NEO4J_DATABASE = DEFAULT_DATABASE +DEFAULT_MEMGRAPH_USERNAME = "" +DEFAULT_MEMGRAPH_PASSWORD = "" -NEPTUNE_SERVICE_NAME = 'neptune-db' -logger = logging.getLogger('client') +NEPTUNE_SERVICE_NAME = "neptune-db" +logger = logging.getLogger("client") # TODO: Constants for states of each long-running job # TODO: add doc links to each command -FORMAT_CSV = 'csv' -FORMAT_OPENCYPHER='opencypher' -FORMAT_NTRIPLE = 'ntriples' -FORMAT_NQUADS = 'nquads' -FORMAT_RDFXML = 'rdfxml' -FORMAT_TURTLE = 'turtle' +FORMAT_CSV = "csv" +FORMAT_OPENCYPHER = "opencypher" +FORMAT_NTRIPLE = "ntriples" +FORMAT_NQUADS = "nquads" +FORMAT_RDFXML = "rdfxml" +FORMAT_TURTLE = "turtle" -PARALLELISM_LOW = 'LOW' -PARALLELISM_MEDIUM = 'MEDIUM' -PARALLELISM_HIGH = 'HIGH' -PARALLELISM_OVERSUBSCRIBE = 'OVERSUBSCRIBE' +PARALLELISM_LOW = "LOW" +PARALLELISM_MEDIUM = "MEDIUM" +PARALLELISM_HIGH = "HIGH" +PARALLELISM_OVERSUBSCRIBE = "OVERSUBSCRIBE" -MODE_RESUME = 'RESUME' -MODE_NEW = 'NEW' -MODE_AUTO = 'AUTO' +MODE_RESUME = "RESUME" +MODE_NEW = "NEW" +MODE_AUTO = "AUTO" LOAD_JOB_MODES = [MODE_RESUME, MODE_NEW, MODE_AUTO] -VALID_FORMATS = [FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE] -PARALLELISM_OPTIONS = [PARALLELISM_LOW, PARALLELISM_MEDIUM, PARALLELISM_HIGH, PARALLELISM_OVERSUBSCRIBE] -LOADER_ACTION = 'loader' - -FINAL_LOAD_STATUSES = ['LOAD_COMPLETED', - 'LOAD_COMMITTED_W_WRITE_CONFLICTS', - 'LOAD_CANCELLED_BY_USER', - 'LOAD_CANCELLED_DUE_TO_ERRORS', - 'LOAD_FAILED', - 'LOAD_UNEXPECTED_ERROR', - 'LOAD_DATA_DEADLOCK', - 'LOAD_DATA_FAILED_DUE_TO_FEED_MODIFIED_OR_DELETED', - 'LOAD_S3_READ_ERROR', - 'LOAD_S3_ACCESS_DENIED_ERROR', - 'LOAD_IN_QUEUE', - 'LOAD_FAILED_BECAUSE_DEPENDENCY_NOT_SATISFIED', - 'LOAD_FAILED_INVALID_REQUEST', ] - -EXPORT_SERVICE_NAME = 'execute-api' -EXPORT_ACTION = 'neptune-export' -EXTRA_HEADERS = {'content-type': 'application/json'} -SPARQL_ACTION = 'sparql' +VALID_FORMATS = [ + FORMAT_CSV, + FORMAT_OPENCYPHER, + FORMAT_NTRIPLE, + FORMAT_NQUADS, + FORMAT_RDFXML, + FORMAT_TURTLE, +] +PARALLELISM_OPTIONS = [ + PARALLELISM_LOW, + PARALLELISM_MEDIUM, + PARALLELISM_HIGH, + PARALLELISM_OVERSUBSCRIBE, +] +LOADER_ACTION = "loader" + +FINAL_LOAD_STATUSES = [ + "LOAD_COMPLETED", + "LOAD_COMMITTED_W_WRITE_CONFLICTS", + "LOAD_CANCELLED_BY_USER", + "LOAD_CANCELLED_DUE_TO_ERRORS", + "LOAD_FAILED", + "LOAD_UNEXPECTED_ERROR", + "LOAD_DATA_DEADLOCK", + "LOAD_DATA_FAILED_DUE_TO_FEED_MODIFIED_OR_DELETED", + "LOAD_S3_READ_ERROR", + "LOAD_S3_ACCESS_DENIED_ERROR", + "LOAD_IN_QUEUE", + "LOAD_FAILED_BECAUSE_DEPENDENCY_NOT_SATISFIED", + "LOAD_FAILED_INVALID_REQUEST", +] + +EXPORT_SERVICE_NAME = "execute-api" +EXPORT_ACTION = "neptune-export" +EXTRA_HEADERS = {"content-type": "application/json"} +SPARQL_ACTION = "sparql" # Constants used by the Stream Viewer. 
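
One behavioral note on the Memgraph defaults above: `get_opencypher_driver`, later in this file, resolves Bolt credentials with a fixed precedence. It uses the Memgraph credentials when `memgraph_auth` is true, otherwise the Neo4j credentials when `neo4j_auth` is true, otherwise no auth at all. A standalone sketch of that precedence (the function name `resolve_bolt_auth` is illustrative only, not part of the patch):

```
# Illustrative sketch of the credential precedence added to
# get_opencypher_driver; resolve_bolt_auth is not part of the patch.
from typing import Optional, Tuple

def resolve_bolt_auth(
    memgraph_auth: bool = False,   # patch default
    memgraph_username: str = "",   # DEFAULT_MEMGRAPH_USERNAME
    memgraph_password: str = "",   # DEFAULT_MEMGRAPH_PASSWORD
    neo4j_auth: bool = True,
    neo4j_username: str = "neo4j",
    neo4j_password: str = "password",
) -> Optional[Tuple[str, str]]:
    if memgraph_auth:
        return (memgraph_username, memgraph_password)
    if neo4j_auth:
        return (neo4j_username, neo4j_password)
    return None

# Defaults fall through to the Neo4j credentials; an explicit "memgraph"
# section in %graph_notebook_config with "auth": true takes precedence.
assert resolve_bolt_auth() == ("neo4j", "password")
assert resolve_bolt_auth(memgraph_auth=True, memgraph_username="username",
                         memgraph_password="password") == ("username", "password")
```
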
-STREAM_AT = 'AT_SEQUENCE_NUMBER' -STREAM_AFTER = 'AFTER_SEQUENCE_NUMBER' -STREAM_TRIM = 'TRIM_HORIZON' -STREAM_LATEST = 'LATEST' -STREAM_COMMIT_TIMESTAMP = 'commitTimestamp' -STREAM_IS_LASTOP = 'isLastOp' -STREAM_EXCEPTION_NOT_FOUND = 'StreamRecordsNotFoundException' -STREAM_EXCEPTION_NOT_ENABLED = 'UnsupportedOperationException' +STREAM_AT = "AT_SEQUENCE_NUMBER" +STREAM_AFTER = "AFTER_SEQUENCE_NUMBER" +STREAM_TRIM = "TRIM_HORIZON" +STREAM_LATEST = "LATEST" +STREAM_COMMIT_TIMESTAMP = "commitTimestamp" +STREAM_IS_LASTOP = "isLastOp" +STREAM_EXCEPTION_NOT_FOUND = "StreamRecordsNotFoundException" +STREAM_EXCEPTION_NOT_ENABLED = "UnsupportedOperationException" # A mapping from the name in the stream_viewer widget dropdown, to the actual Neptune # Streams endpoint (API) name. We do not map 'PropertyGraph' to 'pg' to maintain # compatability with older engine releases that did not have a 'pg' endpoint. -STREAM_PG = 'PropertyGraph' -STREAM_RDF = 'RDF' -STREAM_ENDPOINTS = {STREAM_PG: 'gremlin', STREAM_RDF: 'sparql'} - -NEPTUNE_CONFIG_HOST_IDENTIFIERS = ["neptune.amazonaws.com", "neptune.*.amazonaws.com.cn", "api.aws", "on.aws", "aws.dev"] - -false_str_variants = [False, 'False', 'false', 'FALSE'] - -GRAPHSONV3_VARIANTS = ['graphsonv3', 'graphsonv3d0', 'graphsonserializersv3d0'] -GRAPHSONV2_VARIANTS = ['graphsonv2', 'graphsonv2d0', 'graphsonserializersv2d0'] -GRAPHBINARYV1_VARIANTS = ['graphbinaryv1', 'graphbinary', 'graphbinaryserializersv1'] - -STATISTICS_MODES = ["", "status", "disableAutoCompute", "enableAutoCompute", "refresh", "delete"] +STREAM_PG = "PropertyGraph" +STREAM_RDF = "RDF" +STREAM_ENDPOINTS = {STREAM_PG: "gremlin", STREAM_RDF: "sparql"} + +NEPTUNE_CONFIG_HOST_IDENTIFIERS = [ + "neptune.amazonaws.com", + "neptune.*.amazonaws.com.cn", + "api.aws", + "on.aws", + "aws.dev", +] + +false_str_variants = [False, "False", "false", "FALSE"] + +GRAPHSONV3_VARIANTS = ["graphsonv3", "graphsonv3d0", "graphsonserializersv3d0"] +GRAPHSONV2_VARIANTS = ["graphsonv2", "graphsonv2d0", "graphsonserializersv2d0"] +GRAPHBINARYV1_VARIANTS = ["graphbinaryv1", "graphbinary", "graphbinaryserializersv1"] + +STATISTICS_MODES = [ + "", + "status", + "disableAutoCompute", + "enableAutoCompute", + "refresh", + "delete", +] SUMMARY_MODES = ["", "basic", "detailed"] -STATISTICS_LANGUAGE_INPUTS = ["propertygraph", "pg", "gremlin", "oc", "opencypher", "sparql", "rdf"] +STATISTICS_LANGUAGE_INPUTS = [ + "propertygraph", + "pg", + "gremlin", + "oc", + "opencypher", + "sparql", + "rdf", +] -SPARQL_EXPLAIN_MODES = ['dynamic', 'static', 'details'] -OPENCYPHER_EXPLAIN_MODES = ['dynamic', 'static', 'details'] +SPARQL_EXPLAIN_MODES = ["dynamic", "static", "details"] +OPENCYPHER_EXPLAIN_MODES = ["dynamic", "static", "details"] def is_allowed_neptune_host(hostname: str, host_allowlist: list): @@ -127,25 +164,40 @@ def is_allowed_neptune_host(hostname: str, host_allowlist: list): def get_gremlin_serializer(serializer_str: str): serializer_lower = serializer_str.lower() - if serializer_lower == 'graphbinaryv1': + if serializer_lower == "graphbinaryv1": return serializer.GraphBinarySerializersV1() - elif serializer_lower == 'graphsonv2': + elif serializer_lower == "graphsonv2": return serializer.GraphSONSerializersV2d0() else: return serializer.GraphSONSerializersV3d0() class Client(object): - def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, ssl_verify: bool = True, - region: str = DEFAULT_REGION, sparql_path: str = '/sparql', - gremlin_traversal_source: str = DEFAULT_GREMLIN_TRAVERSAL_SOURCE, - 
gremlin_username: str = '', gremlin_password: str = '', - gremlin_serializer: str = DEFAULT_GREMLIN_SERIALIZER, - neo4j_username: str = DEFAULT_NEO4J_USERNAME, neo4j_password: str = DEFAULT_NEO4J_PASSWORD, - neo4j_auth: bool = True, neo4j_database: str = DEFAULT_NEO4J_DATABASE, - auth=None, session: Session = None, - proxy_host: str = '', proxy_port: int = DEFAULT_PORT, - neptune_hosts: list = None): + def __init__( + self, + host: str, + port: int = DEFAULT_PORT, + ssl: bool = True, + ssl_verify: bool = True, + region: str = DEFAULT_REGION, + sparql_path: str = "/sparql", + gremlin_traversal_source: str = DEFAULT_GREMLIN_TRAVERSAL_SOURCE, + gremlin_username: str = "", + gremlin_password: str = "", + gremlin_serializer: str = DEFAULT_GREMLIN_SERIALIZER, + neo4j_username: str = DEFAULT_NEO4J_USERNAME, + neo4j_password: str = DEFAULT_NEO4J_PASSWORD, + neo4j_auth: bool = True, + neo4j_database: str = DEFAULT_NEO4J_DATABASE, + memgraph_username: str = DEFAULT_MEMGRAPH_USERNAME, + memgraph_password: str = DEFAULT_MEMGRAPH_PASSWORD, + memgraph_auth: bool = False, + auth=None, + session: Session = None, + proxy_host: str = "", + proxy_port: int = DEFAULT_PORT, + neptune_hosts: list = None, + ): self.target_host = host self.target_port = port self.ssl = ssl @@ -161,116 +213,147 @@ def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, ssl_ve self.neo4j_password = neo4j_password self.neo4j_auth = neo4j_auth self.neo4j_database = neo4j_database + self.memgraph_username = memgraph_username + self.memgraph_password = memgraph_password + self.memgraph_auth = memgraph_auth self.region = region self._auth = auth self._session = session self.proxy_host = proxy_host self.proxy_port = proxy_port - self.neptune_hosts = NEPTUNE_CONFIG_HOST_IDENTIFIERS if neptune_hosts is None else neptune_hosts + self.neptune_hosts = ( + NEPTUNE_CONFIG_HOST_IDENTIFIERS if neptune_hosts is None else neptune_hosts + ) - self._http_protocol = 'https' if self.ssl else 'http' - self._ws_protocol = 'wss' if self.ssl else 'ws' + self._http_protocol = "https" if self.ssl else "http" + self._ws_protocol = "wss" if self.ssl else "ws" self._http_session = None @property def host(self): - if self.proxy_host != '': + if self.proxy_host != "": return self.proxy_host return self.target_host @property def port(self): - if self.proxy_host != '': + if self.proxy_host != "": return self.proxy_port return self.target_port def is_neptune_domain(self): - return is_allowed_neptune_host(hostname=self.target_host, host_allowlist=self.neptune_hosts) + return is_allowed_neptune_host( + hostname=self.target_host, host_allowlist=self.neptune_hosts + ) def get_uri_with_port(self, use_websocket=False, use_proxy=False): protocol = self._http_protocol if use_websocket is True: protocol = self._ws_protocol - uri = f'{protocol}://{self.host}:{self.port}' + uri = f"{protocol}://{self.host}:{self.port}" return uri - def sparql_query(self, query: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + def sparql_query( + self, query: str, headers=None, explain: str = "", path: str = "" + ) -> requests.Response: if headers is None: headers = {} - data = {'query': query} + data = {"query": query} return self.do_sparql_request(data, headers, explain, path=path) - def sparql_update(self, update: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + def sparql_update( + self, update: str, headers=None, explain: str = "", path: str = "" + ) -> requests.Response: if headers is None: headers = {} - data = 
{'update': update} + data = {"update": update} return self.do_sparql_request(data, headers, explain, path=path) - def do_sparql_request(self, data: dict, headers=None, explain: str = '', path: str = ''): - if 'content-type' not in headers: - headers['content-type'] = DEFAULT_SPARQL_CONTENT_TYPE + def do_sparql_request( + self, data: dict, headers=None, explain: str = "", path: str = "" + ): + if "content-type" not in headers: + headers["content-type"] = DEFAULT_SPARQL_CONTENT_TYPE - if explain != '': - data['explain'] = explain + if explain != "": + data["explain"] = explain - if path != '': - sparql_path = f'/{path}' - elif self.sparql_path != '': - sparql_path = f'/{self.sparql_path}' + if path != "": + sparql_path = f"/{path}" + elif self.sparql_path != "": + sparql_path = f"/{self.sparql_path}" elif self.is_neptune_domain(): - sparql_path = f'/{SPARQL_ACTION}' + sparql_path = f"/{SPARQL_ACTION}" else: - sparql_path = '' + sparql_path = "" - uri = f'{self._http_protocol}://{self.host}:{self.port}{sparql_path}' - req = self._prepare_request('POST', uri, data=data, headers=headers) + uri = f"{self._http_protocol}://{self.host}:{self.port}{sparql_path}" + req = self._prepare_request("POST", uri, data=data, headers=headers) res = self._http_session.send(req, verify=self.ssl_verify) return res - def sparql(self, query: str, headers=None, explain: str = '', path: str = '') -> requests.Response: + def sparql( + self, query: str, headers=None, explain: str = "", path: str = "" + ) -> requests.Response: if headers is None: headers = {} - s = SPARQLWrapper('') + s = SPARQLWrapper("") s.setQuery(query) query_type = s.queryType.upper() - if query_type in ['SELECT', 'CONSTRUCT', 'ASK', 'DESCRIBE']: + if query_type in ["SELECT", "CONSTRUCT", "ASK", "DESCRIBE"]: return self.sparql_query(query, headers, explain, path=path) else: return self.sparql_update(query, headers, explain, path=path) # TODO: enum/constants for supported types - def sparql_explain(self, query: str, explain: str = 'dynamic', output_format: str = 'text/html', - headers=None, path: str = '') -> requests.Response: + def sparql_explain( + self, + query: str, + explain: str = "dynamic", + output_format: str = "text/html", + headers=None, + path: str = "", + ) -> requests.Response: if headers is None: headers = {} - if 'Accept' not in headers: - headers['Accept'] = output_format + if "Accept" not in headers: + headers["Accept"] = output_format return self.sparql(query, headers, explain, path=path) - def sparql_status(self, query_id: str = ''): - return self._query_status('sparql', query_id=query_id) + def sparql_status(self, query_id: str = ""): + return self._query_status("sparql", query_id=query_id) def sparql_cancel(self, query_id: str, silent: bool = False): - if type(query_id) is not str or query_id == '': - raise ValueError('query_id must be a non-empty string') - return self._query_status('sparql', query_id=query_id, silent=silent, cancelQuery=True) + if type(query_id) is not str or query_id == "": + raise ValueError("query_id must be a non-empty string") + return self._query_status( + "sparql", query_id=query_id, silent=silent, cancelQuery=True + ) def get_gremlin_connection(self, transport_kwargs) -> client.Client: nest_asyncio.apply() - ws_url = f'{self.get_uri_with_port(use_websocket=True)}/gremlin' - request = self._prepare_request('GET', ws_url) - traversal_source = 'g' if self.is_neptune_domain() else self.gremlin_traversal_source - return client.Client(ws_url, traversal_source, username=self.gremlin_username, - 
password=self.gremlin_password, message_serializer=self.gremlin_serializer, - headers=dict(request.headers), **transport_kwargs) + ws_url = f"{self.get_uri_with_port(use_websocket=True)}/gremlin" + request = self._prepare_request("GET", ws_url) + traversal_source = ( + "g" if self.is_neptune_domain() else self.gremlin_traversal_source + ) + return client.Client( + ws_url, + traversal_source, + username=self.gremlin_username, + password=self.gremlin_password, + message_serializer=self.gremlin_serializer, + headers=dict(request.headers), + **transport_kwargs, + ) def gremlin_query(self, query, transport_args=None, bindings=None): if transport_args is None: @@ -284,11 +367,15 @@ def gremlin_query(self, query, transport_args=None, bindings=None): return results except Exception as e: if isinstance(e, GremlinServerError): - source_err = re.compile('The traversal source \\[.] for alias \\[.] is not configured on the server\\.') + source_err = re.compile( + "The traversal source \\[.] for alias \\[.] is not configured on the server\\." + ) if e.status_code == 499 and source_err.search(str(e)): - print("Error returned by the Gremlin Server for the traversal_source specified in notebook " - "configuration. Please ensure that your graph database endpoint supports re-naming of " - "GraphTraversalSource from the default of 'g' in Gremlin Server.") + print( + "Error returned by the Gremlin Server for the traversal_source specified in notebook " + "configuration. Please ensure that your graph database endpoint supports re-naming of " + "GraphTraversalSource from the default of 'g' in Gremlin Server." + ) c.close() raise e @@ -296,77 +383,81 @@ def gremlin_http_query(self, query, headers=None) -> requests.Response: if headers is None: headers = {} - uri = f'{self.get_uri_with_port()}/gremlin' - data = {'gremlin': query} - req = self._prepare_request('POST', uri, data=json.dumps(data), headers=headers) + uri = f"{self.get_uri_with_port()}/gremlin" + data = {"gremlin": query} + req = self._prepare_request("POST", uri, data=json.dumps(data), headers=headers) res = self._http_session.send(req, verify=self.ssl_verify) return res - def gremlin_status(self, query_id: str = '', include_waiting: bool = False): + def gremlin_status(self, query_id: str = "", include_waiting: bool = False): kwargs = {} if include_waiting: - kwargs['includeWaiting'] = True - return self._query_status('gremlin', query_id=query_id, **kwargs) + kwargs["includeWaiting"] = True + return self._query_status("gremlin", query_id=query_id, **kwargs) def gremlin_cancel(self, query_id: str): - if type(query_id) is not str or query_id == '': - raise ValueError('query_id must be a non-empty string') - return self._query_status('gremlin', query_id=query_id, cancelQuery=True) + if type(query_id) is not str or query_id == "": + raise ValueError("query_id must be a non-empty string") + return self._query_status("gremlin", query_id=query_id, cancelQuery=True) def gremlin_explain(self, query: str, args={}) -> requests.Response: - return self._gremlin_query_plan(query=query, plan_type='explain', args=args) + return self._gremlin_query_plan(query=query, plan_type="explain", args=args) def gremlin_profile(self, query: str, args={}) -> requests.Response: - return self._gremlin_query_plan(query=query, plan_type='profile', args=args) - - def _gremlin_query_plan(self, query: str, plan_type: str, args: dict, ) -> requests.Response: - url = f'{self._http_protocol}://{self.host}:{self.port}/gremlin/{plan_type}' - data = {'gremlin': query} + return 
self._gremlin_query_plan(query=query, plan_type="profile", args=args) + + def _gremlin_query_plan( + self, + query: str, + plan_type: str, + args: dict, + ) -> requests.Response: + url = f"{self._http_protocol}://{self.host}:{self.port}/gremlin/{plan_type}" + data = {"gremlin": query} if args: for param, value in args.items(): data[param] = value - req = self._prepare_request('POST', url, data=json.dumps(data)) + req = self._prepare_request("POST", url, data=json.dumps(data)) res = self._http_session.send(req, verify=self.ssl_verify) return res - def opencypher_http(self, query: str, headers: dict = None, explain: str = None, - query_params: dict = None) -> requests.Response: + def opencypher_http( + self, + query: str, + headers: dict = None, + explain: str = None, + query_params: dict = None, + ) -> requests.Response: if headers is None: headers = {} - url = f'{self._http_protocol}://{self.host}:{self.port}/' + url = f"{self._http_protocol}://{self.host}:{self.port}/" if self.is_neptune_domain(): - if 'content-type' not in headers: - headers['content-type'] = 'application/x-www-form-urlencoded' - url += 'openCypher' - data = { - 'query': query - } + if "content-type" not in headers: + headers["content-type"] = "application/x-www-form-urlencoded" + url += "openCypher" + data = {"query": query} if explain: - data['explain'] = explain - headers['Accept'] = "text/html" + data["explain"] = explain + headers["Accept"] = "text/html" if query_params: - data['parameters'] = str(query_params).replace("'", '"') # '{"AUS_code":"AUS","WLG_code":"WLG"}' + data["parameters"] = str(query_params).replace( + "'", '"' + ) # '{"AUS_code":"AUS","WLG_code":"WLG"}' else: - url += 'db/neo4j/tx/commit' - headers['content-type'] = 'application/json' - headers['Accept'] = 'application/vnd.neo4j.jolt+json-seq' - - data_dict = { - "statements": [ - { - "statement": query - } - ] - } + url += "db/neo4j/tx/commit" + headers["content-type"] = "application/json" + headers["Accept"] = "application/vnd.neo4j.jolt+json-seq" + + data_dict = {"statements": [{"statement": query}]} data = json.dumps(data_dict) if self.neo4j_auth: user_and_pass = self.neo4j_username + ":" + self.neo4j_password user_and_pass_base64 = b64encode(user_and_pass.encode()) - headers['authorization'] = user_and_pass_base64 + headers["authorization"] = user_and_pass_base64 - req = self._prepare_request('POST', url, data=data, headers=headers) + req = self._prepare_request("POST", url, data=data, headers=headers) res = self._http_session.send(req, verify=self.ssl_verify) return res @@ -377,26 +468,30 @@ def opencyper_bolt(self, query: str, **kwargs): res = session.run(query, kwargs) data = res.data() except AuthError: - print("Neo4J Bolt request failed with an authentication error. Please ensure that the 'neo4j' section " - "of your %graph_notebook_config contains the correct credentials and auth setting.") + print( + "The Bolt request failed with an authentication error. Please ensure that the 'neo4j' or 'memgraph' section " + "of your %graph_notebook_config contains the correct credentials and auth setting." 
+ )
data = []
driver.close()
return data

- def opencypher_status(self, query_id: str = '', include_waiting: bool = False):
+ def opencypher_status(self, query_id: str = "", include_waiting: bool = False):
kwargs = {}
if include_waiting:
- kwargs['includeWaiting'] = True
- return self._query_status('openCypher', query_id=query_id, **kwargs)
+ kwargs["includeWaiting"] = True
+ return self._query_status("openCypher", query_id=query_id, **kwargs)

def opencypher_cancel(self, query_id, silent: bool = False):
- if type(query_id) is not str or query_id == '':
- raise ValueError('query_id must be a non-empty string')
+ if type(query_id) is not str or query_id == "":
+ raise ValueError("query_id must be a non-empty string")

- return self._query_status('openCypher', query_id=query_id, cancelQuery=True, silent=silent)
+ return self._query_status(
+ "openCypher", query_id=query_id, cancelQuery=True, silent=silent
+ )

def get_opencypher_driver(self):
- url = f'bolt://{self.host}:{self.port}'
+ url = f"bolt://{self.host}:{self.port}"

if self.is_neptune_domain():
if self._session and self.iam_enabled:
@@ -405,7 +500,9 @@ def get_opencypher_driver(self):
status_res.raise_for_status()
status_res_json = status_res.json()
engine_version_raw = status_res_json["dbEngineVersion"]
- engine_version = int(engine_version_raw.rsplit('.', 1)[0].replace('.', ''))
+ engine_version = int(
+ engine_version_raw.rsplit(".", 1)[0].replace(".", "")
+ )

if engine_version >= 1200:
url += "/opencypher"
@@ -413,11 +510,14 @@ def get_opencypher_driver(self):
frozen_creds = credentials.get_frozen_credentials()
auth_final = NeptuneBoltAuthToken(frozen_creds, self.region, url)
else:
- user = 'username'
+ user = "username"
password = DEFAULT_NEO4J_PASSWORD
auth_final = (user, password)
else:
- if self.neo4j_auth:
+ # use the Memgraph credentials when the user has enabled Memgraph auth in the config
+ if self.memgraph_auth:
+ auth_final = (self.memgraph_username, self.memgraph_password)
+ elif self.neo4j_auth:
auth_final = (self.neo4j_username, self.neo4j_password)
else:
auth_final = None
@@ -429,380 +529,473 @@ def stream(self, url, **kwargs) -> requests.Response:
params = {}
for k, v in kwargs.items():
params[k] = v
- req = self._prepare_request('GET', url, params=params,data='')
+ req = self._prepare_request("GET", url, params=params, data="")
res = self._http_session.send(req, verify=self.ssl_verify)
return res.json()

def status(self) -> requests.Response:
- url = f'{self._http_protocol}://{self.host}:{self.port}/status'
- req = self._prepare_request('GET', url, data='')
+ url = f"{self._http_protocol}://{self.host}:{self.port}/status"
+ req = self._prepare_request("GET", url, data="")
res = self._http_session.send(req, verify=self.ssl_verify)
return res

- def load(self, source: str, source_format: str, iam_role_arn: str = None, **kwargs) -> requests.Response:
+ def load(
+ self, source: str, source_format: str, iam_role_arn: str = None, **kwargs
+ ) -> requests.Response:
"""
For a full list of allowed parameters, see aws documentation on the Neptune loader endpoint:
https://docs.aws.amazon.com/neptune/latest/userguide/load-api-reference-load.html
"""
- payload = {
- 'source': source,
- 'format': source_format,
- 'region': self.region
- }
+ payload = {"source": source, "format": source_format, "region": self.region}
if iam_role_arn:
- payload['iamRoleArn'] = iam_role_arn
+ payload["iamRoleArn"] = iam_role_arn
for key, value in kwargs.items():
payload[key] = value
- url =
f"{self._http_protocol}://{self.host}:{self.port}/loader" raw = json.dumps(payload) - req = self._prepare_request('POST', url, data=raw, headers={'content-type': 'application/json'}) + req = self._prepare_request( + "POST", url, data=raw, headers={"content-type": "application/json"} + ) res = self._http_session.send(req, verify=self.ssl_verify) return res - def load_status(self, load_id: str = '', **kwargs) -> requests.Response: + def load_status(self, load_id: str = "", **kwargs) -> requests.Response: params = {} for k, v in kwargs.items(): params[k] = v - if load_id != '': - params['loadId'] = load_id + if load_id != "": + params["loadId"] = load_id - url = f'{self._http_protocol}://{self.host}:{self.port}/loader' - req = self._prepare_request('GET', url, params=params) + url = f"{self._http_protocol}://{self.host}:{self.port}/loader" + req = self._prepare_request("GET", url, params=params) res = self._http_session.send(req, verify=self.ssl_verify) return res def cancel_load(self, load_id: str) -> requests.Response: - url = f'{self._http_protocol}://{self.host}:{self.port}/loader' - params = {'loadId': load_id} - req = self._prepare_request('DELETE', url, params=params) + url = f"{self._http_protocol}://{self.host}:{self.port}/loader" + params = {"loadId": load_id} + req = self._prepare_request("DELETE", url, params=params) res = self._http_session.send(req, verify=self.ssl_verify) return res def initiate_reset(self) -> requests.Response: - data = { - 'action': 'initiateDatabaseReset' - } - url = f'{self._http_protocol}://{self.host}:{self.port}/system' - req = self._prepare_request('POST', url, data=data) + data = {"action": "initiateDatabaseReset"} + url = f"{self._http_protocol}://{self.host}:{self.port}/system" + req = self._prepare_request("POST", url, data=data) res = self._http_session.send(req, verify=self.ssl_verify) return res def perform_reset(self, token: str) -> requests.Response: - data = { - 'action': 'performDatabaseReset', - 'token': token - } - url = f'{self._http_protocol}://{self.host}:{self.port}/system' - req = self._prepare_request('POST', url, data=data) + data = {"action": "performDatabaseReset", "token": token} + url = f"{self._http_protocol}://{self.host}:{self.port}/system" + req = self._prepare_request("POST", url, data=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def dataprocessing_start(self, s3_input_uri: str, s3_output_uri: str, **kwargs) -> requests.Response: + def dataprocessing_start( + self, s3_input_uri: str, s3_output_uri: str, **kwargs + ) -> requests.Response: data = { - 'inputDataS3Location': s3_input_uri, - 'processedDataS3Location': s3_output_uri, + "inputDataS3Location": s3_input_uri, + "processedDataS3Location": s3_output_uri, } for k, v in kwargs.items(): data[k] = v - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing' - req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing" + req = self._prepare_request( + "POST", + url, + data=json.dumps(data), + headers={"content-type": "application/json"}, + ) res = self._http_session.send(req, verify=self.ssl_verify) return res - def dataprocessing_job_status(self, job_id: str, neptune_iam_role_arn: str = '') -> requests.Response: - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}' + def dataprocessing_job_status( + self, job_id: str, neptune_iam_role_arn: str = "" + ) -> 
requests.Response: + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}" data = {} - if neptune_iam_role_arn != '': - data['neptuneIamRoleArn'] = neptune_iam_role_arn - req = self._prepare_request('GET', url, params=data) + if neptune_iam_role_arn != "": + data["neptuneIamRoleArn"] = neptune_iam_role_arn + req = self._prepare_request("GET", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def dataprocessing_list(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing' - data = { - 'maxItems': max_items - } + def dataprocessing_list( + self, max_items: int = 10, neptune_iam_role_arn: str = "" + ) -> requests.Response: + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing" + data = {"maxItems": max_items} - if neptune_iam_role_arn != '': - data['neptuneIamRoleArn'] = neptune_iam_role_arn - req = self._prepare_request('GET', url, params=data) + if neptune_iam_role_arn != "": + data["neptuneIamRoleArn"] = neptune_iam_role_arn + req = self._prepare_request("GET", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def dataprocessing_stop(self, job_id: str, clean=False, neptune_iam_role_arn: str = '') -> requests.Response: - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}' - data = { - 'clean': clean - } - if neptune_iam_role_arn != '': - data['neptuneIamRoleArn'] = neptune_iam_role_arn + def dataprocessing_stop( + self, job_id: str, clean=False, neptune_iam_role_arn: str = "" + ) -> requests.Response: + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/dataprocessing/{job_id}" + data = {"clean": clean} + if neptune_iam_role_arn != "": + data["neptuneIamRoleArn"] = neptune_iam_role_arn - req = self._prepare_request('DELETE', url, params=data) + req = self._prepare_request("DELETE", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def modeltraining_start(self, data_processing_job_id: str, train_model_s3_location: str, - max_hpo_number_of_training_jobs: int, max_hpo_parallel_training_jobs: int, - **kwargs) -> requests.Response: + def modeltraining_start( + self, + data_processing_job_id: str, + train_model_s3_location: str, + max_hpo_number_of_training_jobs: int, + max_hpo_parallel_training_jobs: int, + **kwargs, + ) -> requests.Response: """ for a full list of supported parameters, see: https://docs.aws.amazon.com/neptune/latest/userguide/machine-learning-api-modeltraining.html """ data = { - 'dataProcessingJobId': data_processing_job_id, - 'trainModelS3Location': train_model_s3_location, - 'maxHPONumberOfTrainingJobs': max_hpo_number_of_training_jobs, - 'maxHPOParallelTrainingJobs': max_hpo_parallel_training_jobs + "dataProcessingJobId": data_processing_job_id, + "trainModelS3Location": train_model_s3_location, + "maxHPONumberOfTrainingJobs": max_hpo_number_of_training_jobs, + "maxHPOParallelTrainingJobs": max_hpo_parallel_training_jobs, } for k, v in kwargs.items(): data[k] = v - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining' - req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining" + req = self._prepare_request( + "POST", + url, + data=json.dumps(data), + headers={"content-type": "application/json"}, + ) res = 
self._http_session.send(req, verify=self.ssl_verify) return res - def modeltraining_list(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: - data = { - 'maxItems': max_items - } + def modeltraining_list( + self, max_items: int = 10, neptune_iam_role_arn: str = "" + ) -> requests.Response: + data = {"maxItems": max_items} - if neptune_iam_role_arn != '': - data['neptuneIamRoleArn'] = neptune_iam_role_arn + if neptune_iam_role_arn != "": + data["neptuneIamRoleArn"] = neptune_iam_role_arn - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining' - req = self._prepare_request('GET', url, params=data) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining" + req = self._prepare_request("GET", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def modeltraining_job_status(self, training_job_id: str, neptune_iam_role_arn: str = '') -> requests.Response: - data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}' - req = self._prepare_request('GET', url, params=data) + def modeltraining_job_status( + self, training_job_id: str, neptune_iam_role_arn: str = "" + ) -> requests.Response: + data = ( + {} + if neptune_iam_role_arn == "" + else {"neptuneIamRoleArn": neptune_iam_role_arn} + ) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}" + req = self._prepare_request("GET", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def modeltraining_stop(self, training_job_id: str, neptune_iam_role_arn: str = '', - clean: bool = False) -> requests.Response: + def modeltraining_stop( + self, training_job_id: str, neptune_iam_role_arn: str = "", clean: bool = False + ) -> requests.Response: data = { - 'clean': "TRUE" if clean else "FALSE", + "clean": "TRUE" if clean else "FALSE", } - if neptune_iam_role_arn != '': - data['neptuneIamRoleArn'] = neptune_iam_role_arn + if neptune_iam_role_arn != "": + data["neptuneIamRoleArn"] = neptune_iam_role_arn - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}' - req = self._prepare_request('DELETE', url, params=data) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/modeltraining/{training_job_id}" + req = self._prepare_request("DELETE", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def modeltransform_create(self, output_s3_location: str, dataprocessing_job_id: str = '', - modeltraining_job_id: str = '', training_job_name: str = '', - **kwargs) -> requests.Response: - logger.debug("modeltransform_create initiated with params:" - f"output_s3_location: {output_s3_location}\n" - f"dataprocessing_job_id: {dataprocessing_job_id}\n" - f"modeltraining_job_id: {modeltraining_job_id}\n" - f"training_job_name: {training_job_name}\n" - f"kwargs: {kwargs}") - data = { - 'modelTransformOutputS3Location': output_s3_location - } + def modeltransform_create( + self, + output_s3_location: str, + dataprocessing_job_id: str = "", + modeltraining_job_id: str = "", + training_job_name: str = "", + **kwargs, + ) -> requests.Response: + logger.debug( + "modeltransform_create initiated with params:" + f"output_s3_location: {output_s3_location}\n" + f"dataprocessing_job_id: {dataprocessing_job_id}\n" + f"modeltraining_job_id: {modeltraining_job_id}\n" + f"training_job_name: 
{training_job_name}\n" + f"kwargs: {kwargs}" + ) + data = {"modelTransformOutputS3Location": output_s3_location} if not dataprocessing_job_id and not modeltraining_job_id and training_job_name: - data['trainingJobName'] = training_job_name + data["trainingJobName"] = training_job_name elif dataprocessing_job_id and modeltraining_job_id and not training_job_name: - data['dataProcessingJobId'] = dataprocessing_job_id - data['mlModelTrainingJobId'] = modeltraining_job_id + data["dataProcessingJobId"] = dataprocessing_job_id + data["mlModelTrainingJobId"] = modeltraining_job_id else: raise ValueError( - 'Invalid input. Must only specify either dataprocessing_job_id and modeltraining_job_id or only ' - 'training_job_name') + "Invalid input. Must only specify either dataprocessing_job_id and modeltraining_job_id or only " + "training_job_name" + ) for k, v in kwargs.items(): data[k] = v - headers = { - 'content-type': 'application/json' - } + headers = {"content-type": "application/json"} - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltransform' - req = self._prepare_request('POST', url, data=json.dumps(data), headers=headers) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/modeltransform" + req = self._prepare_request("POST", url, data=json.dumps(data), headers=headers) res = self._http_session.send(req, verify=self.ssl_verify) return res - def modeltransform_status(self, job_id: str, iam_role: str = '') -> requests.Response: + def modeltransform_status( + self, job_id: str, iam_role: str = "" + ) -> requests.Response: data = {} - if iam_role != '': - data['neptuneIamRoleArn'] = iam_role + if iam_role != "": + data["neptuneIamRoleArn"] = iam_role - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltransform/{job_id}' - req = self._prepare_request('GET', url, params=data) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/modeltransform/{job_id}" + req = self._prepare_request("GET", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def modeltransform_list(self, iam_role: str = '', max_items: int = 10) -> requests.Response: - data = { - 'maxItems': max_items - } + def modeltransform_list( + self, iam_role: str = "", max_items: int = 10 + ) -> requests.Response: + data = {"maxItems": max_items} - if iam_role != '': - data['neptuneIamRoleArn'] = iam_role + if iam_role != "": + data["neptuneIamRoleArn"] = iam_role - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltransform' - req = self._prepare_request('GET', url, params=data) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/modeltransform" + req = self._prepare_request("GET", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def modeltransform_stop(self, job_id: str, iam_role: str = '', clean: bool = False) -> requests.Response: - data = { - 'clean': 'TRUE' if clean else 'FALSE' - } - if iam_role != '': - data['neptuneIamRoleArn'] = iam_role + def modeltransform_stop( + self, job_id: str, iam_role: str = "", clean: bool = False + ) -> requests.Response: + data = {"clean": "TRUE" if clean else "FALSE"} + if iam_role != "": + data["neptuneIamRoleArn"] = iam_role - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/modeltransform/{job_id}' - req = self._prepare_request('DELETE', url, params=data) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/modeltransform/{job_id}" + req = self._prepare_request("DELETE", url, params=data) res = self._http_session.send(req, 
verify=self.ssl_verify) return res - def endpoints_create(self, model_training_job_id: str = '', model_transform_job_id: str = '', - **kwargs) -> requests.Response: + def endpoints_create( + self, + model_training_job_id: str = "", + model_transform_job_id: str = "", + **kwargs, + ) -> requests.Response: data = {} if model_training_job_id and not model_transform_job_id: - data['mlModelTrainingJobId'] = model_training_job_id + data["mlModelTrainingJobId"] = model_training_job_id elif model_transform_job_id and not model_training_job_id: - data['mlModelTransformJobId'] = model_transform_job_id + data["mlModelTransformJobId"] = model_transform_job_id else: - raise ValueError('Invalid input. Must either specify model_training_job_id or model_transform_job_id, ' - 'and not both.') + raise ValueError( + "Invalid input. Must either specify model_training_job_id or model_transform_job_id, " + "and not both." + ) for k, v in kwargs.items(): data[k] = v - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints' - req = self._prepare_request('POST', url, data=json.dumps(data), headers={'content-type': 'application/json'}) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/endpoints" + req = self._prepare_request( + "POST", + url, + data=json.dumps(data), + headers={"content-type": "application/json"}, + ) res = self._http_session.send(req, verify=self.ssl_verify) return res - def endpoints_status(self, endpoint_id: str, neptune_iam_role_arn: str = '') -> requests.Response: - data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}' - req = self._prepare_request('GET', url, params=data) + def endpoints_status( + self, endpoint_id: str, neptune_iam_role_arn: str = "" + ) -> requests.Response: + data = ( + {} + if neptune_iam_role_arn == "" + else {"neptuneIamRoleArn": neptune_iam_role_arn} + ) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}" + req = self._prepare_request("GET", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def endpoints_delete(self, endpoint_id: str, neptune_iam_role_arn: str = '') -> requests.Response: - data = {} if neptune_iam_role_arn == '' else {'neptuneIamRoleArn': neptune_iam_role_arn} - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}' - req = self._prepare_request('DELETE', url, params=data) + def endpoints_delete( + self, endpoint_id: str, neptune_iam_role_arn: str = "" + ) -> requests.Response: + data = ( + {} + if neptune_iam_role_arn == "" + else {"neptuneIamRoleArn": neptune_iam_role_arn} + ) + url = f"{self._http_protocol}://{self.host}:{self.port}/ml/endpoints/{endpoint_id}" + req = self._prepare_request("DELETE", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res - def endpoints(self, max_items: int = 10, neptune_iam_role_arn: str = '') -> requests.Response: - data = { - 'maxItems': max_items - } - if neptune_iam_role_arn != '': - data['neptuneIamRoleArn'] = neptune_iam_role_arn + def endpoints( + self, max_items: int = 10, neptune_iam_role_arn: str = "" + ) -> requests.Response: + data = {"maxItems": max_items} + if neptune_iam_role_arn != "": + data["neptuneIamRoleArn"] = neptune_iam_role_arn - url = f'{self._http_protocol}://{self.host}:{self.port}/ml/endpoints' - req = self._prepare_request('GET', url, params=data) + url = 
f"{self._http_protocol}://{self.host}:{self.port}/ml/endpoints" + req = self._prepare_request("GET", url, params=data) res = self._http_session.send(req, verify=self.ssl_verify) return res def export(self, host: str, params: dict, ssl: bool = True) -> requests.Response: - protocol = 'https' if ssl else 'http' - url = f'{protocol}://{host}/{EXPORT_ACTION}' - req = self._prepare_request('POST', url, data=json.dumps(params), service="execute-api") + protocol = "https" if ssl else "http" + url = f"{protocol}://{host}/{EXPORT_ACTION}" + req = self._prepare_request( + "POST", url, data=json.dumps(params), service="execute-api" + ) res = self._http_session.send(req, verify=self.ssl_verify) return res def export_status(self, host, job_id, ssl: bool = True) -> requests.Response: - protocol = 'https' if ssl else 'http' - url = f'{protocol}://{host}/{EXPORT_ACTION}/{job_id}' - req = self._prepare_request('GET', url, service="execute-api") + protocol = "https" if ssl else "http" + url = f"{protocol}://{host}/{EXPORT_ACTION}/{job_id}" + req = self._prepare_request("GET", url, service="execute-api") res = self._http_session.send(req, verify=self.ssl_verify) return res - def _query_status(self, language: str, *, query_id: str = '', **kwargs) -> requests.Response: + def _query_status( + self, language: str, *, query_id: str = "", **kwargs + ) -> requests.Response: data = {} - if query_id != '': - data['queryId'] = query_id + if query_id != "": + data["queryId"] = query_id for k, v in kwargs.items(): data[k] = v - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - url = f'{self._http_protocol}://{self.host}:{self.port}/{language}/status' - req = self._prepare_request('POST', url, data=data, headers=headers) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + url = f"{self._http_protocol}://{self.host}:{self.port}/{language}/status" + req = self._prepare_request("POST", url, data=data, headers=headers) res = self._http_session.send(req, verify=self.ssl_verify) return res - def statistics(self, language: str, summary: bool = False, mode: str = '') -> requests.Response: - headers = { - 'Accept': 'application/json' - } + def statistics( + self, language: str, summary: bool = False, mode: str = "" + ) -> requests.Response: + headers = {"Accept": "application/json"} if language in ["gremlin", "oc", "opencypher"]: language = "pg" elif language == "sparql": language = "rdf" - url = f'{self._http_protocol}://{self.host}:{self.port}/{language}/statistics' - data = {'mode': mode} + url = f"{self._http_protocol}://{self.host}:{self.port}/{language}/statistics" + data = {"mode": mode} if summary: - summary_url = url + '/summary' + summary_url = url + "/summary" if mode: - summary_mode_param = '?mode=' + mode + summary_mode_param = "?mode=" + mode summary_url += summary_mode_param - req = self._prepare_request('GET', summary_url, headers=headers) + req = self._prepare_request("GET", summary_url, headers=headers) else: - if mode in ['', 'status']: - req = self._prepare_request('GET', url, headers=headers) - elif mode == 'delete': - req = self._prepare_request('DELETE', url, headers=headers) + if mode in ["", "status"]: + req = self._prepare_request("GET", url, headers=headers) + elif mode == "delete": + req = self._prepare_request("DELETE", url, headers=headers) else: - req = self._prepare_request('POST', url, data=json.dumps(data), headers=headers) + req = self._prepare_request( + "POST", url, data=json.dumps(data), headers=headers + ) res = self._http_session.send(req) return res - 
def _prepare_request(self, method, url, *, data=None, params=None, headers=None, service=NEPTUNE_SERVICE_NAME):
+    def _prepare_request(
+        self,
+        method,
+        url,
+        *,
+        data=None,
+        params=None,
+        headers=None,
+        service=NEPTUNE_SERVICE_NAME,
+    ):
         self._ensure_http_session()
-        if self.proxy_host != '':
+        if self.proxy_host != "":
             headers = {} if headers is None else headers
             headers["Host"] = self.target_host
-        request = requests.Request(method=method, url=url, data=data, params=params, headers=headers, auth=self._auth)
+        request = requests.Request(
+            method=method,
+            url=url,
+            data=data,
+            params=params,
+            headers=headers,
+            auth=self._auth,
+        )
         if self._session is not None:
-            aws_request = self._get_aws_request(method=method, url=url, data=data, params=params, headers=headers,
-                                                service=service)
+            aws_request = self._get_aws_request(
+                method=method,
+                url=url,
+                data=data,
+                params=params,
+                headers=headers,
+                service=service,
+            )
             request.headers = dict(aws_request.headers)
         return request.prepare()

-    def _get_aws_request(self, method, url, *, data=None, params=None, headers=None, service=NEPTUNE_SERVICE_NAME):
-        req = AWSRequest(method=method, url=url, data=data, params=params, headers=headers)
+    def _get_aws_request(
+        self,
+        method,
+        url,
+        *,
+        data=None,
+        params=None,
+        headers=None,
+        service=NEPTUNE_SERVICE_NAME,
+    ):
+        req = AWSRequest(
+            method=method, url=url, data=data, params=params, headers=headers
+        )
         if self.iam_enabled:
             credentials = self._session.get_credentials()
             try:
                 frozen_creds = credentials.get_frozen_credentials()
             except AttributeError:
-                print("Could not find valid IAM credentials in any the following locations:\n")
-                print("env, assume-role, assume-role-with-web-identity, sso, shared-credential-file, custom-process, "
-                      "config-file, ec2-credentials-file, boto-config, container-role, iam-role\n")
-                print("Go to https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html for more "
-                      "details on configuring your IAM credentials.")
+                print(
+                    "Could not find valid IAM credentials in any of the following locations:\n"
+                )
+                print(
+                    "env, assume-role, assume-role-with-web-identity, sso, shared-credential-file, custom-process, "
+                    "config-file, ec2-credentials-file, boto-config, container-role, iam-role\n"
+                )
+                print(
+                    "Go to https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html for more "
+                    "details on configuring your IAM credentials."
+ ) return req SigV4Auth(frozen_creds, service, self.region).add_auth(req) prepared_iam_req = req.prepare() @@ -834,63 +1027,69 @@ def __init__(self, args: dict = None): self.args = args def with_host(self, host: str): - self.args['host'] = host + self.args["host"] = host return ClientBuilder(self.args) def with_port(self, port: int): - self.args['port'] = port + self.args["port"] = port return ClientBuilder(self.args) def with_sparql_path(self, path: str): - self.args['sparql_path'] = path + self.args["sparql_path"] = path return ClientBuilder(self.args) def with_gremlin_traversal_source(self, traversal_source: str): - self.args['gremlin_traversal_source'] = traversal_source + self.args["gremlin_traversal_source"] = traversal_source return ClientBuilder(self.args) - + def with_gremlin_login(self, username: str, password: str): - self.args['gremlin_username'] = username - self.args['gremlin_password'] = password + self.args["gremlin_username"] = username + self.args["gremlin_password"] = password return ClientBuilder(self.args) def with_gremlin_serializer(self, message_serializer: str): - self.args['gremlin_serializer'] = message_serializer + self.args["gremlin_serializer"] = message_serializer return ClientBuilder(self.args) - + def with_neo4j_login(self, username: str, password: str, auth: bool, database: str): - self.args['neo4j_username'] = username - self.args['neo4j_password'] = password - self.args['neo4j_auth'] = auth - self.args['neo4j_database'] = database + self.args["neo4j_username"] = username + self.args["neo4j_password"] = password + self.args["neo4j_auth"] = auth + self.args["neo4j_database"] = database + return ClientBuilder(self.args) + + def with_memgraph_login(self, username: str, password: str, auth: bool): + self.args["memgraph_username"] = username + self.args["memgraph_password"] = password + self.args["memgraph_auth"] = auth return ClientBuilder(self.args) def with_tls(self, tls: bool): - self.args['ssl'] = tls + self.args["ssl"] = tls return ClientBuilder(self.args) def with_ssl_verify(self, ssl_verify: bool): - self.args['ssl_verify'] = ssl_verify + self.args["ssl_verify"] = ssl_verify return ClientBuilder(self.args) def with_region(self, region: str): - self.args['region'] = region + self.args["region"] = region return ClientBuilder(self.args) def with_iam(self, session: Session): - self.args['session'] = session + self.args["session"] = session return ClientBuilder(self.args) def with_proxy_host(self, host: str): - self.args['proxy_host'] = host + self.args["proxy_host"] = host return ClientBuilder(self.args) def with_proxy_port(self, proxy_port: int): - self.args['proxy_port'] = proxy_port + self.args["proxy_port"] = proxy_port return ClientBuilder(self.args) def with_custom_neptune_hosts(self, neptune_hosts: list): - self.args['neptune_hosts'] = neptune_hosts + self.args["neptune_hosts"] = neptune_hosts return ClientBuilder(self.args) def build(self) -> Client: diff --git a/src/graph_notebook/notebooks/01-Getting-Started/06-Supply-Chain-Analysis-with-Memgraph.ipynb b/src/graph_notebook/notebooks/01-Getting-Started/06-Supply-Chain-Analysis-with-Memgraph.ipynb new file mode 100644 index 00000000..a79c7516 --- /dev/null +++ b/src/graph_notebook/notebooks/01-Getting-Started/06-Supply-Chain-Analysis-with-Memgraph.ipynb @@ -0,0 +1,655 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Supply Chain Analysis with Memgraph\n", + "\n", + "## Table of contents \n", + "1. [Introduction](#introduction)\n", + "2. 
[Connect to Memgraph](#connect-to-memgraph)\n",
+    "3. [Create the dataset](#create-the-dataset)\n",
+    "4. [Supply Chain Analysis](#supply-chain-analysis)\n",
+    "  - [Acquiring critical hubs in the network with betweenness centrality](#critical-hubs)\n",
+    "  - [Get ingredients provided by the supplier](#get-ingredients)\n",
+    "  - [Pathfinding for necessary ingredients](#pathfinding)\n",
+    "  - [Checking dependencies of the product with ancestors](#ancestors)\n",
+    "  - [Ancestors graph](#ancestors-graph)\n",
+    "  - [Checking possible products for production with descendants](#descendants)\n",
+    "  - [Descendants graph](#descendants-graph)\n",
+    "  - [Getting the order of execution with topological sort](#topological-sort)\n",
+    "5. [Conclusion](#conclusion)\n",
+    "\n",
+    "## 1. Introduction\n",
+    "\n",
+    "In supply chain management, a network of process steps is drawn to minimize product delivery time from production to shipping. To this day, process steps are mostly optimized by staff members, which is error-prone and often yields under-optimized solutions. Designing an optimal schedule by hand also takes considerable time that staff could instead spend on work requiring their expertise, which is a strong argument for automating supply chain scheduling.\n",
+    "\n",
+    "In this notebook, you'll learn how to start Memgraph, connect to it, and run Cypher queries to explore a supply chain, learning more about the power of graphs in that domain along the way. \n",
+    "\n",
+    "## 2. Connect to Memgraph\n",
+    "\n",
+    "[Memgraph](https://memgraph.com/) is an open-source in-memory graph database built for highly performant and advanced analytical insights. Memgraph is Neo4j Bolt protocol compatible and uses the standardized Cypher query language. \n",
+    "\n",
+    "For a quick start, run the following command in your terminal to start the Memgraph Platform in a Docker container: \n",
+    "\n",
+    "```\n",
+    "docker run -it -p 7687:7687 -p 7444:7444 -p 3000:3000 -e MEMGRAPH=\"--bolt-server-name-for-init=Neo4j/\" memgraph/memgraph-platform\n",
+    "```\n",
+    "\n",
+    "The above command starts the Memgraph database, MAGE (graph algorithms library) and Memgraph Lab (visual user interface). For additional instructions on setting up and running Memgraph locally, refer to the [Memgraph documentation](https://memgraph.com/docs/memgraph/installation). Connecting from the Graph Notebook works only if the `--bolt-server-name-for-init` setting is modified, as in the command above. 
For more information on changing configuration settings, refer to our [how-to guide](https://memgraph.com/docs/memgraph/how-to-guides/config-logs).\n", + "\n", + "\n", + "After the local setup of Memgraph is complete, set the following configuration to connect from the Graph Notebook:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "set notebook config to:\n", + "{\n", + " \"host\": \"localhost\",\n", + " \"port\": 7687,\n", + " \"proxy_host\": \"\",\n", + " \"proxy_port\": 8182,\n", + " \"ssl\": false,\n", + " \"ssl_verify\": true,\n", + " \"sparql\": {\n", + " \"path\": \"\"\n", + " },\n", + " \"gremlin\": {\n", + " \"traversal_source\": \"g\",\n", + " \"username\": \"\",\n", + " \"password\": \"\",\n", + " \"message_serializer\": \"graphsonv3\"\n", + " },\n", + " \"neo4j\": {\n", + " \"username\": \"neo4j\",\n", + " \"password\": \"password\",\n", + " \"auth\": true,\n", + " \"database\": null\n", + " }\n", + "}\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%graph_notebook_config\n", + "{\n", + " \"host\": \"localhost\",\n", + " \"port\": 7687,\n", + " \"ssl\": false,\n", + " \"neo4j\": {\n", + " \"auth\": true,\n", + " \"database\": \"\"\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Create the dataset\n", + "\n", + "You can query Memgraph via Bolt protocol designed for efficient communication with graph databases. Memgraph supports versions 1, 4 and 5 of the protocol. Specify the `%%oc bolt` option when submitting queries to the Bolt endpoint.\n", + "\n", + "Before we analyze the dataset, we have to import it. The easiest way to do that with the `graph-notebook` is to run `CREATE` Cypher queries. Once you run the code cell below, the Memgraph database will be populated with a supply chain dataset. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1937e41be1fc42a5b1f7c64dd0c7527c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%oc bolt\n", + "CREATE (sup1:Supplier {id: 1, name: \"Supplissimus\", centrality: 0.027920624240525559})\n", + "CREATE (sup2:Supplier {id: 2, name: \"Supplionis\", centrality: 0.002840909090909091})\n", + "CREATE (sup3:Supplier {id: 3, name: \"MegaSupplies\", centrality: 0.055822172619047615})\n", + "CREATE (sup4:Supplier {id: 4, name: \"Supplies4you\", centrality: 0})\n", + "CREATE (ing1:Ingredient {id: 1, name: \"Ingredient 1\", centrality: 0.0042365042365042358})\n", + "CREATE (ing2:Ingredient {id: 2, name: \"Ingredient 2\", centrality: 0.077438394705712787})\n", + "CREATE (ing3:Ingredient {id: 3, name: \"Ingredient 3\", centrality: 0.025363208468374868})\n", + "CREATE (ing4:Ingredient {id: 4, name: \"Ingredient 4\", centrality: 0.036831658149140731})\n", + "CREATE (ing5:Ingredient {id: 5, name: \"Ingredient 5\", centrality: 0.018939393939393933})\n", + "CREATE (ing6:Ingredient {id: 6, name: \"Ingredient 6\", centrality: 0.018939393939393933})\n", + "CREATE (ing7:Ingredient {id: 7, name: \"Ingredient 7\", centrality: 0.018939393939393933})\n", + "CREATE (ing8:Ingredient {id: 8, name: \"Ingredient 8\", centrality: 0.066602827149702143})\n", + "CREATE (ing9:Ingredient {id: 9, name: \"Ingredient 9\", centrality: 0.076719345469345446})\n", + "CREATE (ing10:Ingredient {id: 10, name: \"Ingredient 10\", centrality: 0.13523010455818119})\n", + "CREATE (pro1:Product {id: 1, name: \"Intermediate product 1\", centrality: 0.075849577597110474})\n", + "CREATE (pro2:Product {id: 2, name: \"Intermediate product 2\", centrality: 0.30307542895342809})\n", + "CREATE (pro3:Product {id: 3, name: \"Intermediate product 3\", centrality: 0.27450054057784318})\n", + "CREATE (pro4:Product {id: 4, name: \"Intermediate product 4\", centrality: 0.12564154013699291})\n", + "CREATE (pro5:Product {id: 5, name: \"Intermediate product 5\", centrality: 0.018604622671718259})\n", + "CREATE (pro6:FinalProduct:Product {id: 6, name: \"Final product 1\", centrality: 0.02814078282828282})\n", + "CREATE (pro7:FinalProduct:Product {id: 7, name: \"Final product 2\", centrality: 0.035353535353535366})\n", + "CREATE (pro8:FinalProduct:Product {id: 8, name: \"Final product 3\", centrality: 0.1539119291441273})\n", + "CREATE (shi1:Shipping {id: 1, name: \"Shipping point 1\", centrality: 0.0066761363636363633})\n", + "CREATE (shi2:Shipping {id: 2, name: \"Shipping point 2\", centrality: 0})\n", + "CREATE (rec1:Recipe {id: 1, name: \"Recipe for product 1\", centrality: 0.077470165525264201})\n", + "CREATE (rec2:Recipe {id: 2, name: \"Recipe for product 2\", centrality: 0.15612639008415902})\n", + "CREATE (rec3:Recipe {id: 3, name: \"Recipe for product 3\", centrality: 0.27750650680338179})\n", + "CREATE (rec4:Recipe {id: 4, name: \"Recipe for product 4\", centrality: 0.072996207394185345})\n", + "CREATE (rec5:Recipe {id: 5, name: \"Recipe for product 5\", centrality: 0.051091351458998513})\n", + "CREATE (rec6:Recipe {id: 6, name: \"Recipe for final product 1\", centrality: 0.23304036135039422})\n", + "CREATE (rec7:Recipe {id: 7, name: \"Recipe 
for final product 2\", centrality: 0.24386567715587651})\n", + "CREATE (rec8:Recipe {id: 8, name: \"Recipe for final product 3 - variant 1\", centrality: 0.088413170560519616})\n", + "CREATE (rec9:Recipe {id: 9, name: \"Recipe for final product 3 - variant 2\", centrality: 0.18098001437059097})\n", + "CREATE (rec10:Recipe {id: 10, name: \"Recipe for final product 3 - variant 3\", centrality: 0.082068494800692962})\n", + "CREATE (sup1)-[:SUPPLIES]->(ing1)\n", + "CREATE (sup1)-[:SUPPLIES]->(ing2)\n", + "CREATE (sup1)-[:SUPPLIES]->(ing3)\n", + "CREATE (sup1)-[:SUPPLIES]->(ing4)\n", + "CREATE (sup2)-[:SUPPLIES]->(ing5)\n", + "CREATE (sup2)-[:SUPPLIES]->(ing6)\n", + "CREATE (sup2)-[:SUPPLIES]->(ing7)\n", + "CREATE (sup3)-[:SUPPLIES]->(ing8)\n", + "CREATE (sup3)-[:SUPPLIES]->(ing9)\n", + "CREATE (sup4)-[:SUPPLIES]->(ing10)\n", + "CREATE (pro1)-[:FORMS {quantity: 30}]->(rec1)\n", + "CREATE (pro2)-[:FORMS {quantity: 50}]->(rec1)\n", + "CREATE (pro2)-[:FORMS {quantity: 100}]->(rec2)\n", + "CREATE (pro2)-[:FORMS {quantity: 50}]->(rec10)\n", + "CREATE (pro3)-[:FORMS {quantity: 80}]->(rec1)\n", + "CREATE (pro3)-[:FORMS {quantity: 200}]->(rec2)\n", + "CREATE (pro4)-[:FORMS {quantity: 150}]->(rec2)\n", + "CREATE (pro4)-[:FORMS {quantity: 70}]->(rec10)\n", + "CREATE (pro5)-[:FORMS {quantity: 10}]->(rec3)\n", + "CREATE (pro6)-[:FORMS {quantity: 90}]->(rec3)\n", + "CREATE (pro7)-[:FORMS {quantity: 100}]->(rec3)\n", + "CREATE (pro8)-[:FORMS {quantity: 200}]->(rec3)\n", + "CREATE (ing9)-[:FORMS {quantity: 300}]->(rec4)\n", + "CREATE (ing9)-[:FORMS {quantity: 80}]->(rec5)\n", + "CREATE (ing10)-[:FORMS {quantity: 120}]->(rec4)\n", + "CREATE (ing10)-[:FORMS {quantity: 5}]->(rec5)\n", + "CREATE (ing10)-[:FORMS {quantity: 100}]->(rec9)\n", + "CREATE (ing1)-[:FORMS {quantity: 15}]->(rec6)\n", + "CREATE (ing2)-[:FORMS {quantity: 25}]->(rec6)\n", + "CREATE (ing2)-[:FORMS {quantity: 65}]->(rec7)\n", + "CREATE (ing2)-[:FORMS {quantity: 100}]->(rec9)\n", + "CREATE (ing3)-[:FORMS {quantity: 35}]->(rec6)\n", + "CREATE (ing3)-[:FORMS {quantity: 120}]->(rec7)\n", + "CREATE (ing4)-[:FORMS {quantity: 130}]->(rec7)\n", + "CREATE (ing4)-[:FORMS {quantity: 140}]->(rec8)\n", + "CREATE (ing5)-[:FORMS {quantity: 85}]->(rec8)\n", + "CREATE (pro6)-[:SHIPS_WITH]->(shi1)\n", + "CREATE (pro7)-[:SHIPS_WITH]->(shi1)\n", + "CREATE (pro8)-[:SHIPS_WITH]->(shi2)\n", + "CREATE (rec1)-[:PRODUCES {quantity: 1}]->(pro1)\n", + "CREATE (rec2)-[:PRODUCES {quantity: 1}]->(pro2)\n", + "CREATE (rec3)-[:PRODUCES {quantity: 1}]->(pro3)\n", + "CREATE (rec4)-[:PRODUCES {quantity: 1}]->(pro4)\n", + "CREATE (rec5)-[:PRODUCES {quantity: 1}]->(pro5)\n", + "CREATE (rec6)-[:PRODUCES {quantity: 1}]->(pro6)\n", + "CREATE (rec7)-[:PRODUCES {quantity: 1}]->(pro7)\n", + "CREATE (rec8)-[:PRODUCES {quantity: 1}]->(pro8)\n", + "CREATE (rec9)-[:PRODUCES {quantity: 1}]->(pro8)\n", + "CREATE (rec10)-[:PRODUCES {quantity: 1}]->(pro8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To ensure the data is stored in Memgraph, head to `localhost:3000` and check out Memgraph Lab, a visual user interface. You can see node and relationship count there, explore, query and visualize data. 
Besides that, you can head over to the Graph Schema tab to check whether the imported data is appropriately modeled.\n",
+    "\n",
+    "\"drawing\"\n",
+    "\n",
+    "Another way of verifying that the database is not empty is by running the following query:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "94c05fef3dae4b7ea32c44305b3dc150",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH (n)\n",
+    "RETURN count(n)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Great! The data is imported into Memgraph, and we can start analyzing it!\n",
+    "\n",
+    "## 4. Supply Chain Analysis\n",
+    "\n",
+    "### Acquiring critical hubs in the network with betweenness centrality\n",
+    "\n",
+    "If a critical path of the pipeline fails at some point, some products may never get produced. Certain pipeline failures affect few products and can wait to be fixed if their priority is low; others need immediate attention. \n",
+    "\n",
+    "An algorithm like *betweenness centrality* helps find exactly those critical spots. It detects hubs in the network by counting, over all pairs of nodes in the graph, how many paths between them cross each node. \n",
+    "\n",
+    "By running the query below, we can see that if some intermediate products are missing, none of the final products get produced. That is a massive failure mode in the pipeline, and it needs extra care to prevent (alternative measures, additional monitoring of intermediate product production, etc.).\n",
+    "\n",
+    "Memgraph supports betweenness centrality through the **betweenness_centrality_online.set()** procedure, which also works in streaming scenarios.",
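+    "\n",
+    "For intuition, the same measure can be computed offline on a toy graph with plain Python and [NetworkX](https://networkx.org/) (a standalone sketch, independent of Memgraph; the node names are made up):\n",
+    "\n",
+    "```python\n",
+    "import networkx as nx\n",
+    "\n",
+    "# Tiny stand-in for a supply chain: supplier -> ingredient -> recipe -> product\n",
+    "g = nx.DiGraph()\n",
+    "g.add_edges_from([\n",
+    "    (\"supplier\", \"ingredient\"),\n",
+    "    (\"ingredient\", \"recipe\"),\n",
+    "    (\"recipe\", \"final product\"),\n",
+    "])\n",
+    "\n",
+    "# Normalized fraction of shortest paths that pass through each node\n",
+    "print(nx.betweenness_centrality(g))\n",
+    "```"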
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d230ed39ce8a463184441573f2f46a3c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "CALL betweenness_centrality_online.set() YIELD betweenness_centrality, node\n",
+    "SET node.centrality = betweenness_centrality;"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "106ce078c18b40b3877a522a39320b17",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH (n)-[r]->(m) RETURN n, r, m;"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Get ingredients provided by the supplier\n",
+    "\n",
+    "Since a graph database can be the ultimate source of truth between different data sources, it makes sense to store all the information about our suppliers in Memgraph.\n",
+    "\n",
+    "From there, we can query, for example, which ingredients are supplied by the supplier *Supplissimus*."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4e3a97704eb742e096e3dec50c180480",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH (s:Supplier {name:\"Supplissimus\"})-[r:SUPPLIES]->(i:Ingredient)\n",
+    "RETURN i;\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Pathfinding for necessary ingredients\n",
+    "\n",
+    "So far we have seen a 1-hop query, which essentially looks for the nearest neighbors in the network.\n",
+    "\n",
+    "Memgraph also supports graph traversals, e.g., **Breadth-first search (BFS)**. With it, we can see which ingredients are used to form the product with ID 6."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "73a8e07717fe40b3aa2363ef92445811",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH p=(i:Ingredient)-[*BFS]->(f:FinalProduct {id:6})\n",
+    "RETURN p"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Checking dependencies of the product with ancestors\n",
+    "\n",
+    "Traversals are not all that graph databases offer: whole graph algorithms can be run directly on graph storage like Memgraph. \n",
+    "\n",
+    "This query determines what happens before the **:FinalProduct** with ID 6 gets produced. It uses the **graph_util.ancestors** procedure, which captures all the nodes from which a path to the destination node (the FinalProduct) exists. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2923fd5b429c4c2d907d75d6f9a51a55",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH (f:FinalProduct {id:6})\n",
+    "CALL graph_util.ancestors(f) YIELD ancestors\n",
+    "UNWIND ancestors AS ancestor\n",
+    "RETURN ancestor;\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Ancestors graph\n",
+    "\n",
+    "The previous procedure yielded all the preceding nodes, but that alone tells us little, since we don't know how those nodes are connected. \n",
+    "\n",
+    "To connect the nodes, we can use another MAGE extension procedure called **graph_util.connect_nodes**, which will connect the nodes with the corresponding relationships between them."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1822d21a34764816bb56a51455f5b94a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH (f:FinalProduct {id:6})\n",
+    "CALL graph_util.ancestors(f) YIELD ancestors\n",
+    "WITH ancestors + [f] AS nodes\n",
+    "CALL graph_util.connect_nodes(nodes) YIELD connections\n",
+    "UNWIND nodes + connections AS graph\n",
+    "RETURN graph;"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Checking possible products for production with descendants\n",
+    "\n",
+    "We can also look at the pipeline from the other direction. From a supplier's view, we can see how many products or operations in the pipeline it affects. In case the supplier is unavailable, this information helps minimize the risk.\n",
+    "\n",
+    "Just as with ancestors, we use the **graph_util.descendants** procedure, which yields all the nodes to which a path exists from the source node (the supplier *Supplissimus* in this case).",
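+    "\n",
+    "Conceptually, the descendants of a node are just its reachable set. In plain Python (an illustrative sketch, independent of Memgraph; the adjacency list is made up), this is a breadth-first search:\n",
+    "\n",
+    "```python\n",
+    "from collections import deque\n",
+    "\n",
+    "def descendants(adj: dict, source: str) -> set:\n",
+    "    # BFS over an adjacency list, collecting every node reachable from source\n",
+    "    seen, queue = set(), deque([source])\n",
+    "    while queue:\n",
+    "        for nxt in adj.get(queue.popleft(), []):\n",
+    "            if nxt not in seen:\n",
+    "                seen.add(nxt)\n",
+    "                queue.append(nxt)\n",
+    "    return seen\n",
+    "\n",
+    "adj = {\"Supplissimus\": [\"Ingredient 1\"], \"Ingredient 1\": [\"Recipe 6\"]}\n",
+    "print(descendants(adj, \"Supplissimus\"))  # {'Ingredient 1', 'Recipe 6'}\n",
+    "```"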
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "94686742f289477a9bc5af2b46fdc3bb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH (s:Supplier {name: \"Supplissimus\"})\n",
+    "CALL graph_util.descendants(s) YIELD descendants\n",
+    "UNWIND descendants AS descendant\n",
+    "RETURN descendant;\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Descendants graph\n",
+    "\n",
+    "We do the same as before and connect the nodes with the **graph_util.connect_nodes** procedure."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f44295f3e66a490b87bc5d76f22211b0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH (s:Supplier {name: \"Supplissimus\"})\n",
+    "CALL graph_util.descendants(s) YIELD descendants\n",
+    "WITH descendants + [s] AS nodes\n",
+    "CALL graph_util.connect_nodes(nodes) YIELD connections\n",
+    "UNWIND nodes + connections AS graph\n",
+    "RETURN graph;\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Getting the order of execution with topological sort\n",
+    "\n",
+    "Some operations can't start before others finish. Such dependencies block parts of the pipeline until the jobs they wait on complete; once a job with no unmet dependencies finishes, the jobs depending on it are released and can start executing. \n",
+    "\n",
+    "In graph theory, that's precisely what topological sort captures. It orders the nodes (jobs, operations, or products) so that every node comes after all the nodes it depends on, yielding a valid order of execution.\n",
+    "\n",
+    "For sorting the nodes topologically, we will use the **graph_util.topological_sort** procedure."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "08b3e918750243f99fb881f9cb9d786c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%oc bolt\n",
+    "MATCH p=(r:Recipe)-[*bfs]->(f:FinalProduct)\n",
+    "WITH project(p) AS graph\n",
+    "CALL graph_util.topological_sort(graph) YIELD sorted_nodes\n",
+    "UNWIND sorted_nodes AS nodes\n",
+    "RETURN nodes.name;"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 5. Conclusion\n",
+    "\n",
+    "Hopefully, you have learned about Memgraph and supply chains, and how intuitive it is to analyze them with Cypher queries. 
If you want to understand why graph databases are the future of network resource optimization, head over to [Memgraph's blog post](https://memgraph.com/blog/graphs-databases-are-the-future-for-network-resource-optimization). For any questions regarding this notebook, Cypher, Memgraph or graphs in general, [join our Discord community](https://www.discord.gg/memgraph). \n", + "\n", + "

⬆️ GO TO TOP ⬆️

" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/integration/IntegrationTest.py b/test/integration/IntegrationTest.py index abf7e577..eff66462 100644 --- a/test/integration/IntegrationTest.py +++ b/test/integration/IntegrationTest.py @@ -9,39 +9,67 @@ from graph_notebook.configuration.generate_config import Configuration, AuthModeEnum from graph_notebook.configuration.get_config import get_config -from graph_notebook.neptune.client import ClientBuilder, NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host +from graph_notebook.neptune.client import ( + ClientBuilder, + NEPTUNE_CONFIG_HOST_IDENTIFIERS, + is_allowed_neptune_host, +) from test.integration.NeptuneIntegrationWorkflowSteps import TEST_CONFIG_PATH def setup_client_builder(config: Configuration) -> ClientBuilder: if is_allowed_neptune_host(config.host, NEPTUNE_CONFIG_HOST_IDENTIFIERS): - builder = ClientBuilder() \ - .with_host(config.host) \ - .with_port(config.port) \ - .with_region(config.aws_region) \ - .with_tls(config.ssl) \ - .with_ssl_verify(config.ssl_verify) \ - .with_proxy_host(config.proxy_host) \ - .with_proxy_port(config.proxy_port) \ - .with_sparql_path(config.sparql.path) \ - .with_gremlin_traversal_source(config.gremlin.traversal_source) \ - .with_gremlin_serializer(config.gremlin.message_serializer) \ - .with_neo4j_login(config.neo4j.username, config.neo4j.password, config.neo4j.auth, config.neo4j.database) + builder = ( + ClientBuilder() + .with_host(config.host) + .with_port(config.port) + .with_region(config.aws_region) + .with_tls(config.ssl) + .with_ssl_verify(config.ssl_verify) + .with_proxy_host(config.proxy_host) + .with_proxy_port(config.proxy_port) + .with_sparql_path(config.sparql.path) + .with_gremlin_traversal_source(config.gremlin.traversal_source) + .with_gremlin_serializer(config.gremlin.message_serializer) + .with_neo4j_login( + config.neo4j.username, + config.neo4j.password, + config.neo4j.auth, + config.neo4j.database, + ) + .with_memgraph_login( + config.memgraph.username, + config.memgraph.password, + config.memgraph.auth, + ) + ) if config.auth_mode == AuthModeEnum.IAM: builder = builder.with_iam(get_session()) else: - builder = ClientBuilder() \ - .with_host(config.host) \ - .with_port(config.port) \ - .with_tls(config.ssl) \ - .with_ssl_verify(config.ssl_verify) \ - .with_proxy_host(config.proxy_host) \ - .with_proxy_port(config.proxy_port) \ - .with_sparql_path(config.sparql.path) \ - .with_gremlin_traversal_source(config.gremlin.traversal_source) \ - .with_gremlin_login(config.gremlin.username, config.gremlin.password) \ - .with_gremlin_serializer(config.gremlin.message_serializer) \ - .with_neo4j_login(config.neo4j.username, config.neo4j.password, config.neo4j.auth, config.neo4j.database) + builder = ( + ClientBuilder() + .with_host(config.host) + .with_port(config.port) + .with_tls(config.ssl) + .with_ssl_verify(config.ssl_verify) + .with_proxy_host(config.proxy_host) + .with_proxy_port(config.proxy_port) + .with_sparql_path(config.sparql.path) + .with_gremlin_traversal_source(config.gremlin.traversal_source) + .with_gremlin_login(config.gremlin.username, config.gremlin.password) + 
.with_gremlin_serializer(config.gremlin.message_serializer) + .with_neo4j_login( + config.neo4j.username, + config.neo4j.password, + config.neo4j.auth, + config.neo4j.database, + ) + .with_memgraph_login( + config.memgraph.username, + config.memgraph.password, + config.memgraph.auth, + ) + ) return builder diff --git a/test/integration/iam/ml/__init__.py b/test/integration/iam/ml/__init__.py index ec2a1f6b..d0c20335 100644 --- a/test/integration/iam/ml/__init__.py +++ b/test/integration/iam/ml/__init__.py @@ -10,21 +10,33 @@ def setup_iam_client(config: Configuration) -> Client: - client = ClientBuilder() \ - .with_host(config.host) \ - .with_port(config.port) \ - .with_region(config.aws_region) \ - .with_tls(config.ssl) \ - .with_ssl_verify(config.ssl_verify) \ - .with_proxy_host(config.proxy_host) \ - .with_proxy_port(config.proxy_port) \ - .with_sparql_path(config.sparql.path) \ - .with_gremlin_traversal_source(config.gremlin.traversal_source) \ - .with_gremlin_login(config.gremlin.username, config.gremlin.password) \ - .with_gremlin_serializer(config.gremlin.message_serializer) \ - .with_neo4j_login(config.neo4j.username, config.neo4j.password, config.neo4j.auth, config.neo4j.database) \ - .with_iam(get_session()) \ + client = ( + ClientBuilder() + .with_host(config.host) + .with_port(config.port) + .with_region(config.aws_region) + .with_tls(config.ssl) + .with_ssl_verify(config.ssl_verify) + .with_proxy_host(config.proxy_host) + .with_proxy_port(config.proxy_port) + .with_sparql_path(config.sparql.path) + .with_gremlin_traversal_source(config.gremlin.traversal_source) + .with_gremlin_login(config.gremlin.username, config.gremlin.password) + .with_gremlin_serializer(config.gremlin.message_serializer) + .with_neo4j_login( + config.neo4j.username, + config.neo4j.password, + config.neo4j.auth, + config.neo4j.database, + ) + .with_memgraph_login( + config.memgraph.username, + config.memgraph.password, + config.memgraph.auth, + ) + .with_iam(get_session()) .build() + ) assert client.host == config.host assert client.port == config.port @@ -39,7 +51,9 @@ def setup_iam_client(config: Configuration) -> Client: assert client.neo4j_username == config.neo4j.username assert client.neo4j_password == config.neo4j.password assert client.neo4j_auth == config.neo4j.auth - assert client.neo4j_database == config.neo4j.database + assert client.memgraph_username == config.memgraph.username + assert client.memgraph_password == config.memgraph.password + assert client.memgraph_auth == config.memgraph.auth assert client.ssl is config.ssl assert client.ssl_verify is config.ssl_verify return client diff --git a/test/integration/iam/notebook/test_open_cypher_graph_notebook.py b/test/integration/iam/notebook/test_open_cypher_graph_notebook.py index 15d28a30..46b18fab 100644 --- a/test/integration/iam/notebook/test_open_cypher_graph_notebook.py +++ b/test/integration/iam/notebook/test_open_cypher_graph_notebook.py @@ -16,101 +16,101 @@ def setUp(self) -> None: @pytest.mark.jupyter @pytest.mark.opencypher def test_opencypher_query(self): - query = '''MATCH(a)-->(b) + query = """MATCH(a)-->(b) RETURN b - LIMIT 1''' + LIMIT 1""" - store_to_var = 'res' - cell = f'''%%oc --store-to {store_to_var} - {query}''' + store_to_var = "res" + cell = f"""%%oc --store-to {store_to_var} + {query}""" self.ip.run_cell(cell) - self.assertFalse('graph_notebook_error' in self.ip.user_ns) + self.assertFalse("graph_notebook_error" in self.ip.user_ns) res = self.ip.user_ns[store_to_var] # TODO: how can we get a look at the objects which 
were displayed? - assert len(res['results']) == 1 - assert 'b' in res['results'][0] + assert len(res["results"]) == 1 + assert "b" in res["results"][0] @pytest.mark.jupyter @pytest.mark.opencypher def test_opencypher_bolt(self): - query = '''MATCH(a)-->(b) + query = """MATCH(a)-->(b) RETURN b - LIMIT 1''' + LIMIT 1""" - store_to_var = 'res' - cell = f'''%%oc bolt --store-to {store_to_var} - {query}''' + store_to_var = "res" + cell = f"""%%oc bolt --store-to {store_to_var} + {query}""" self.ip.run_cell(cell) - self.assertFalse('graph_notebook_error' in self.ip.user_ns) + self.assertFalse("graph_notebook_error" in self.ip.user_ns) res = self.ip.user_ns[store_to_var] assert len(res) == 1 - assert 'b' in res[0] + assert "b" in res[0] @pytest.mark.jupyter @pytest.mark.opencypher def test_opencypher_query_parameterized_with_var_input(self): expected_league_name = "English Premier League" - query = 'MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name' + query = "MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name" - store_to_var = 'res' - self.ip.user_ns['params_var'] = {'LEAGUE_NICKNAME': 'EPL'} - cell = f'''%%oc --query-parameters params_var --store-to {store_to_var} - {query}''' + store_to_var = "res" + self.ip.user_ns["params_var"] = {"LEAGUE_NICKNAME": "EPL"} + cell = f"""%%oc --query-parameters params_var --store-to {store_to_var} + {query}""" self.ip.run_cell(cell) res = self.ip.user_ns[store_to_var] - assert len(res['results']) == 1 - assert expected_league_name == res['results'][0]['l.name'] + assert len(res["results"]) == 1 + assert expected_league_name == res["results"][0]["l.name"] @pytest.mark.jupyter @pytest.mark.opencypher def test_opencypher_query_parameterized_with_str_input(self): expected_league_name = "English Premier League" - query = 'MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name' + query = "MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name" - store_to_var = 'res' + store_to_var = "res" params_str = '{"LEAGUE_NICKNAME":"EPL"}' - cell = f'''%%oc --query-parameters {params_str} --store-to {store_to_var} - {query}''' + cell = f"""%%oc --query-parameters {params_str} --store-to {store_to_var} + {query}""" self.ip.run_cell(cell) res = self.ip.user_ns[store_to_var] - assert len(res['results']) == 1 - assert expected_league_name == res['results'][0]['l.name'] + assert len(res["results"]) == 1 + assert expected_league_name == res["results"][0]["l.name"] @pytest.mark.jupyter @pytest.mark.opencypher def test_opencypher_query_parameterized_invalid(self): - query = 'MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name' + query = "MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name" - self.ip.user_ns['params_var'] = ['LEAGUE_NICKNAME'] - cell = f'''%%oc --query-parameters params_var - {query}''' + self.ip.user_ns["params_var"] = ["LEAGUE_NICKNAME"] + cell = f"""%%oc --query-parameters params_var + {query}""" self.ip.run_cell(cell) - self.assertTrue('graph_notebook_error' in self.ip.user_ns) + self.assertTrue("graph_notebook_error" in self.ip.user_ns) @pytest.mark.jupyter @pytest.mark.opencypher def test_opencypher_bolt_parameterized(self): expected_league_name = "English Premier League" - query = 'MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name' + query = "MATCH (l:League {nickname: $LEAGUE_NICKNAME}) RETURN l.name" - store_to_var = 'res' + store_to_var = "res" params_var = '{"LEAGUE_NICKNAME":"EPL"}' - cell = f'''%%oc bolt --query-parameters {params_var} --store-to {store_to_var} - {query}''' + cell = f"""%%oc bolt 
--query-parameters {params_var} --store-to {store_to_var} + {query}""" self.ip.run_cell(cell) - self.assertFalse('graph_notebook_error' in self.ip.user_ns) + self.assertFalse("graph_notebook_error" in self.ip.user_ns) res = self.ip.user_ns[store_to_var] assert len(res) == 1 - assert expected_league_name == res[0]['l.name'] + assert expected_league_name == res[0]["l.name"] @pytest.mark.jupyter def test_load_opencypher_config(self): - config = '''{ + config = """{ "host": "localhost", "port": 8182, "auth_mode": "DEFAULT", @@ -122,7 +122,12 @@ def test_load_opencypher_config(self): "password": "password", "auth": true, "database": "" + }, + "memgraph": { + "username": "", + "password": "", + "auth": false } - }''' + }""" - self.ip.run_cell_magic('graph_notebook_config', '', config) + self.ip.run_cell_magic("graph_notebook_config", "", config) From 34a8a05fd8a93bb51658d412cbd44fa708b19946 Mon Sep 17 00:00:00 2001 From: katarinasupe Date: Fri, 28 Jul 2023 15:34:13 +0200 Subject: [PATCH 2/3] Add link to PR --- ChangeLog.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ChangeLog.md b/ChangeLog.md index bd326c93..9e91f79e 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -4,7 +4,7 @@ Starting with v1.31.6, this file will contain a record of major features and upd ## Upcoming - Added `--explain-type` option to `%%gremlin` ([Link to PR](https://github.com/aws/graph-notebook/pull/503)) -- Added Memgraph as an additional graph database and the supply chain analysis notebook ([Link to PR]()) +- Added Memgraph as an additional graph database and the supply chain analysis notebook ([Link to PR](https://github.com/aws/graph-notebook/pull/513)) ## Release 3.8.2 (June 5, 2023) - New Sample Applications - Healthcare and Life Sciences notebooks ([Link to PR](https://github.com/aws/graph-notebook/pull/484)) From 61b48f175bec85db2fbb3e4ef2e8d1821c656d15 Mon Sep 17 00:00:00 2001 From: katarinasupe Date: Wed, 2 Aug 2023 16:17:56 +0200 Subject: [PATCH 3/3] Add memgraph database config option --- README.md | 1 + .../configuration/generate_config.py | 10 + .../configuration/get_config.py | 5 + src/graph_notebook/magics/graph_magic.py | 1 + src/graph_notebook/neptune/client.py | 7 +- ...-Supply-Chain-Analysis-with-Memgraph.ipynb | 274 +++--------------- test/integration/IntegrationTest.py | 2 + test/integration/iam/ml/__init__.py | 1 + 8 files changed, 64 insertions(+), 237 deletions(-) diff --git a/README.md b/README.md index 5faeef1b..4c88e8c5 100644 --- a/README.md +++ b/README.md @@ -193,6 +193,7 @@ Configuration options can be set using the `%graph_notebook_config` magic comman | sparql | SPARQL connection object | ``` { "path": "sparql" } ``` | string | | gremlin | Gremlin connection object | ``` { "username": "", "password": "", "traversal_source": "g", "message_serializer": "graphsonv3" } ```| string | | neo4j | Neo4J connection object |``` { "username": "neo4j", "password": "password", "auth": true, "database": null } ``` | string | +| memgraph | Memgraph connection object |``` { "username": "", "password": "", "auth": false, "database": "memgraph" } ``` | string | ### Gremlin Server diff --git a/src/graph_notebook/configuration/generate_config.py b/src/graph_notebook/configuration/generate_config.py index d54cd497..baf2457f 100644 --- a/src/graph_notebook/configuration/generate_config.py +++ b/src/graph_notebook/configuration/generate_config.py @@ -19,6 +19,7 @@ DEFAULT_NEO4J_DATABASE, DEFAULT_MEMGRAPH_USERNAME, DEFAULT_MEMGRAPH_PASSWORD, + DEFAULT_MEMGRAPH_DATABASE, 
     NEPTUNE_CONFIG_HOST_IDENTIFIERS,
     is_allowed_neptune_host,
     false_str_variants,
@@ -153,6 +154,7 @@ def __init__(
         username: str = "",
         password: str = "",
         auth: bool = False,
+        database: str = "",
     ):
         """
         :param username: login user for the Memgraph endpoint
@@ -164,10 +166,13 @@
             username = DEFAULT_MEMGRAPH_USERNAME
         if password == "":
             password = DEFAULT_MEMGRAPH_PASSWORD
+        if database == "":
+            database = DEFAULT_MEMGRAPH_DATABASE

         self.username = username
         self.password = password
         self.auth = True if auth in [True, "True", "true", "TRUE"] else False
+        self.database = database

     def to_dict(self):
         return self.__dict__
@@ -437,6 +442,11 @@ def generate_default_config():
         help="whether to use auth for Memgraph connections or not [True|False]",
         default=True,
     )
+    parser.add_argument(
+        "--memgraph_database",
+        help="the name of the database to use for Memgraph",
+        default=DEFAULT_MEMGRAPH_DATABASE,
+    )

     args = parser.parse_args()
     auth_mode_arg = (
diff --git a/src/graph_notebook/configuration/get_config.py b/src/graph_notebook/configuration/get_config.py
index 696a6473..9e7b0467 100644
--- a/src/graph_notebook/configuration/get_config.py
+++ b/src/graph_notebook/configuration/get_config.py
@@ -23,6 +23,7 @@
     DEFAULT_NEO4J_DATABASE,
     DEFAULT_MEMGRAPH_USERNAME,
     DEFAULT_MEMGRAPH_PASSWORD,
+    DEFAULT_MEMGRAPH_DATABASE,
 )

 neptune_params = ["auth_mode", "load_from_s3_arn", "aws_region"]
@@ -81,6 +82,10 @@ def get_config_from_dict(
             print(
                 "Ignoring Memgraph custom authentication, Amazon Neptune does not support this functionality.\n"
             )
+        if memgraph_section.to_dict()["database"] != DEFAULT_MEMGRAPH_DATABASE:
+            print(
+                "Ignoring Memgraph custom database, Amazon Neptune does not support multiple databases.\n"
+            )
         config = Configuration(
             host=data["host"],
             port=data["port"],
diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py
index d3cecdc9..13c1ff82 100644
--- a/src/graph_notebook/magics/graph_magic.py
+++ b/src/graph_notebook/magics/graph_magic.py
@@ -426,6 +426,7 @@ def _generate_client_from_config(self, config: Configuration):
                 config.memgraph.username,
                 config.memgraph.password,
                 config.memgraph.auth,
+                config.memgraph.database,
             )
         )

diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py
index 1bfe6678..fa78b0bd 100644
--- a/src/graph_notebook/neptune/client.py
+++ b/src/graph_notebook/neptune/client.py
@@ -39,6 +39,7 @@
 DEFAULT_NEO4J_DATABASE = DEFAULT_DATABASE
 DEFAULT_MEMGRAPH_USERNAME = ""
 DEFAULT_MEMGRAPH_PASSWORD = ""
+DEFAULT_MEMGRAPH_DATABASE = "memgraph"

 NEPTUNE_SERVICE_NAME = "neptune-db"
 logger = logging.getLogger("client")
@@ -191,6 +192,7 @@ def __init__(
         neo4j_database: str = DEFAULT_NEO4J_DATABASE,
         memgraph_username: str = DEFAULT_MEMGRAPH_USERNAME,
         memgraph_password: str = DEFAULT_MEMGRAPH_PASSWORD,
+        memgraph_database: str = DEFAULT_MEMGRAPH_DATABASE,
         memgraph_auth: bool = False,
         auth=None,
         session: Session = None,
@@ -1058,10 +1060,13 @@ def with_neo4j_login(self, username: str, password: str, auth: bool, database: s
         self.args["neo4j_database"] = database
         return ClientBuilder(self.args)

-    def with_memgraph_login(self, username: str, password: str, auth: bool):
+    def with_memgraph_login(
+        self, username: str, password: str, auth: bool, database: str
+    ):
         self.args["memgraph_username"] = username
         self.args["memgraph_password"] = password
         self.args["memgraph_auth"] = auth
+        self.args["memgraph_database"] = database
         return ClientBuilder(self.args)

     def with_tls(self, tls: bool):
diff --git a/src/graph_notebook/notebooks/01-Getting-Started/06-Supply-Chain-Analysis-with-Memgraph.ipynb b/src/graph_notebook/notebooks/01-Getting-Started/06-Supply-Chain-Analysis-with-Memgraph.ipynb
index a79c7516..26d061e9 100644
--- a/src/graph_notebook/notebooks/01-Getting-Started/06-Supply-Chain-Analysis-with-Memgraph.ipynb
+++ b/src/graph_notebook/notebooks/01-Getting-Started/06-Supply-Chain-Analysis-with-Memgraph.ipynb
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -45,66 +46,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "set notebook config to:\n",
-      "{\n",
-      " \"host\": \"localhost\",\n",
-      " \"port\": 7687,\n",
-      " \"proxy_host\": \"\",\n",
-      " \"proxy_port\": 8182,\n",
-      " \"ssl\": false,\n",
-      " \"ssl_verify\": true,\n",
-      " \"sparql\": {\n",
-      " \"path\": \"\"\n",
-      " },\n",
-      " \"gremlin\": {\n",
-      " \"traversal_source\": \"g\",\n",
-      " \"username\": \"\",\n",
-      " \"password\": \"\",\n",
-      " \"message_serializer\": \"graphsonv3\"\n",
-      " },\n",
-      " \"neo4j\": {\n",
-      " \"username\": \"neo4j\",\n",
-      " \"password\": \"password\",\n",
-      " \"auth\": true,\n",
-      " \"database\": null\n",
-      " }\n",
-      "}\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       ""
-      ]
-     },
-     "execution_count": 5,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%graph_notebook_config\n",
     "{\n",
     " \"host\": \"localhost\",\n",
     " \"port\": 7687,\n",
-    " \"ssl\": false,\n",
-    " \"neo4j\": {\n",
-    " \"auth\": true,\n",
-    " \"database\": \"\"\n",
-    " }\n",
+    " \"ssl\": false\n",
    "}"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -117,26 +74,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": null,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "1937e41be1fc42a5b1f7c64dd0c7527c",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "CREATE (sup1:Supplier {id: 1, name: \"Supplissimus\", centrality: 0.027920624240525559})\n",
@@ -225,6 +167,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -237,26 +180,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": null,
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "94c05fef3dae4b7ea32c44305b3dc150",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH (n)\n",
@@ -264,6 +192,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -284,26 +213,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": null,
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "d230ed39ce8a463184441573f2f46a3c",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "CALL betweenness_centrality_online.set() YIELD betweenness_centrality, node\n",
@@ -312,30 +226,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "106ce078c18b40b3877a522a39320b17",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH (n)-[r]->(m) RETURN n, r, m;"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -348,26 +248,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": null,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "4e3a97704eb742e096e3dec50c180480",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH (s:Supplier {name:\"Supplissimus\"})-[r:SUPPLIES]->(i:Ingredient)\n",
@@ -375,6 +260,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -387,26 +273,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": null,
    "metadata": {
     "scrolled": false
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "73a8e07717fe40b3aa2363ef92445811",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH p=(i:Ingredient)-[*BFS]->(f:FinalProduct {id:6})\n",
@@ -414,6 +285,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -426,24 +298,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "2923fd5b429c4c2d907d75d6f9a51a55",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH (f:FinalProduct {id:6})\n",
@@ -453,6 +310,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -465,24 +323,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "1822d21a34764816bb56a51455f5b94a",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH (f:FinalProduct {id:6})\n",
@@ -494,6 +337,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -506,26 +350,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": null,
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "94686742f289477a9bc5af2b46fdc3bb",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH (s:Supplier {name: \"Supplissimus\"})\n",
@@ -535,6 +364,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -545,26 +375,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": null,
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "f44295f3e66a490b87bc5d76f22211b0",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH (s:Supplier {name: \"Supplissimus\"})\n",
@@ -576,6 +391,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -590,26 +406,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": null,
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "08b3e918750243f99fb881f9cb9d786c",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Tab(children=(Output(layout=Layout(max_height='600px', max_width='940px', overflow='scroll')), Output(layout=L…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%%oc bolt\n",
     "MATCH p=(r:Recipe)-[*bfs]->(f:FinalProduct)\n",
@@ -620,6 +421,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -647,7 +449,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.12"
+   "version": "3.9.16"
   }
  },
 "nbformat": 4,
diff --git a/test/integration/IntegrationTest.py b/test/integration/IntegrationTest.py
index eff66462..9f051f05 100644
--- a/test/integration/IntegrationTest.py
+++ b/test/integration/IntegrationTest.py
@@ -41,6 +41,7 @@ def setup_client_builder(config: Configuration) -> ClientBuilder:
             config.memgraph.username,
             config.memgraph.password,
             config.memgraph.auth,
+            config.memgraph.database,
         )
     )
     if config.auth_mode == AuthModeEnum.IAM:
@@ -68,6 +69,7 @@ def setup_client_builder(config: Configuration) -> ClientBuilder:
             config.memgraph.username,
             config.memgraph.password,
             config.memgraph.auth,
+            config.memgraph.database,
         )
     )

diff --git a/test/integration/iam/ml/__init__.py b/test/integration/iam/ml/__init__.py
index d0c20335..621268aa 100644
--- a/test/integration/iam/ml/__init__.py
+++ b/test/integration/iam/ml/__init__.py
@@ -33,6 +33,7 @@ def setup_iam_client(config: Configuration) -> Client:
             config.memgraph.username,
             config.memgraph.password,
             config.memgraph.auth,
+            config.memgraph.database,
         )
         .with_iam(get_session())
         .build()
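
Putting the pieces together, here is a minimal sketch of a notebook configuration that exercises the new `database` field added by PATCH 3/3 (an illustration, not part of the patches themselves). The `memgraph` keys mirror the README table and the test configuration above; an empty `database` falls back to `DEFAULT_MEMGRAPH_DATABASE` (`"memgraph"`), so the value only matters for a non-default database:

```
%%graph_notebook_config
{
  "host": "localhost",
  "port": 7687,
  "ssl": false,
  "memgraph": {
    "username": "",
    "password": "",
    "auth": false,
    "database": "memgraph"
  }
}
```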
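
Similarly, a minimal sketch of how the widened `with_memgraph_login` signature threads the new argument through `ClientBuilder` and into the `Client` constructor's `memgraph_database` parameter. Only methods visible in this patch are used; the `builder` variable is assumed to be a `ClientBuilder` already primed with host and port arguments, which are configured outside the hunks shown here:

```python
# Sketch under the assumptions above: `builder` already carries host/port args.
client = (
    builder
    .with_tls(False)  # plain Bolt endpoint, matching "ssl": false above
    .with_memgraph_login(
        username="",          # DEFAULT_MEMGRAPH_USERNAME
        password="",          # DEFAULT_MEMGRAPH_PASSWORD
        auth=False,
        database="memgraph",  # new argument; defaults via DEFAULT_MEMGRAPH_DATABASE
    )
    .build()
)
```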