diff --git a/sparkmagic/sparkmagic/magics/remotesparkmagics.py b/sparkmagic/sparkmagic/magics/remotesparkmagics.py index fcd41e29..fec06bd5 100644 --- a/sparkmagic/sparkmagic/magics/remotesparkmagics.py +++ b/sparkmagic/sparkmagic/magics/remotesparkmagics.py @@ -30,7 +30,11 @@ from sparkmagic.controllerwidget.magicscontrollerwidget import MagicsControllerWidget from sparkmagic.livyclientlib.endpoint import Endpoint from sparkmagic.magics.sparkmagicsbase import SparkMagicBase -from sparkmagic.livyclientlib.exceptions import handle_expected_exceptions +from sparkmagic.livyclientlib.exceptions import ( + handle_expected_exceptions, + wrap_unexpected_exceptions, + BadUserDataException +) @magics_class @@ -326,6 +330,63 @@ def _print_local_info(self): ) ) + @magic_arguments() + @argument( + "-i", + "--input", + type=str, + default=None, + help="If present, indicated variable will be stored in variable" + " in Spark's context.", + ) + @argument( + "-t", + "--vartype", + type=str, + default="str", + help="Optionally specify the type of input variable. " + "Available: 'str' - string(default) or 'df' - Pandas DataFrame", + ) + @argument( + "-n", + "--varname", + type=str, + default=None, + help="Optionally specify the custom name for the input variable.", + ) + @argument( + "-m", + "--maxrows", + type=int, + default=2500, + help="Maximum number of rows that will be pulled back " + "from the local dataframe", + ) + @line_magic + @needs_local_scope + @wrap_unexpected_exceptions + @handle_expected_exceptions + def send_to_spark(self, line, local_ns=None): + """Magic to send a variable to spark cluster. + + Usage: %send_to_spark -i variable -t str -n var + + -i VAR_NAME: Local Pandas DataFrame(or String) of name VAR_NAME will be available in the %%spark context as a + Spark dataframe(or String) with the same name. + -t TYPE: Specifies the type of variable passed as -i. Available options are: + `str` for string and `df` for Pandas DataFrame. Optional, defaults to `str`. 
+ -n NAME: Custom name of variable passed as -i. Optional, defaults to -i variable name. + -m MAXROWS: Maximum amount of Pandas rows that will be sent to Spark. Defaults to 2500. + + """ + args = parse_argstring_or_throw(self.send_to_spark, line) + + if not args.input: + raise BadUserDataException("-i param not provided.") + + self.do_send_to_spark( + "", args.input, args.vartype, args.varname, args.maxrows, None + ) def load_ipython_extension(ip): ip.register_magics(RemoteSparkMagics) diff --git a/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py b/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py index bb2b27a0..f2e6d0eb 100644 --- a/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py +++ b/sparkmagic/sparkmagic/tests/test_remotesparkmagics.py @@ -14,6 +14,9 @@ from sparkmagic.livyclientlib.exceptions import * from sparkmagic.livyclientlib.sqlquery import SQLQuery from sparkmagic.livyclientlib.sparkstorecommand import SparkStoreCommand +import pandas as pd + +import unittest magic = None spark_controller = None @@ -585,6 +588,24 @@ def test_run_sql_command_knows_how_to_be_quiet(): ) assert result is None +def test_send_to_spark_with_str_variable(): + magic.do_send_to_spark = MagicMock() + line = "-i str_var_name -n var_name_in_spark" + result = magic.send_to_spark(line) + magic.do_send_to_spark.assert_called_once_with( + '', 'str_var_name', 'str', 'var_name_in_spark', 2500, None + ) + +def test_send_to_spark_with_pandas_variable(): + magic.do_send_to_spark = MagicMock() + df = pd.DataFrame({'key': ['val1', 'val2']}) + shell['df_var_name'] = df + line = "-i df_var_name -t df -n var_name_in_spark" + result = magic.send_to_spark(line) + magic.do_send_to_spark.assert_called_once_with( + '', 'df_var_name', 'df', 'var_name_in_spark', 2500, None + ) + def test_logs_subcommand(): get_logs_method = MagicMock() @@ -623,4 +644,4 @@ def test_logs_exception(): assert result is None ipython_display.send_error.assert_called_once_with( 
         EXPECTED_ERROR_MSG.format(get_logs_method.side_effect)
-    )
+    )