Up

sparktk.frame.schema module

# vim: set encoding=utf-8

#  Copyright (c) 2016 Intel Corporation 
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from sparktk.dtypes import dtypes, matrix

def jvm_scala_schema(sc):
    """
    Returns the JVM-side SchemaHelper object, reached through ``sc._jvm``.

    :param sc: context object whose ``_jvm`` attribute exposes the JVM package tree
    :return: the ``org.trustedanalytics.sparktk.frame.SchemaHelper`` handle
    """
    jvm = sc._jvm
    return jvm.org.trustedanalytics.sparktk.frame.SchemaHelper

def get_schema_for_columns(schema, selected_columns):
    """
    Returns the schema entries for the selected columns, in the order they were selected.

    :param schema: sequence of schema entries whose first element is the column name
    :param selected_columns: iterable of column names to pick out of schema
    :return: list of the matching schema entries
    :raises ValueError: propagated from get_indices_for_selected_columns for an unknown column name
    """
    return [schema[position] for position in get_indices_for_selected_columns(schema, selected_columns)]

def get_indices_for_selected_columns(schema, selected_columns):
    """
    Returns the positional index of each selected column within the schema.

    :param schema: sequence of schema entries; only the first element of each (the column name) is read
    :param selected_columns: iterable of column names; the result follows this order
    :return: list of int indices into schema, one per entry in selected_columns
    :raises ValueError: if any selected column name is not present in the schema
    """
    schema_columns = [col[0] for col in schema]
    indices = []
    for column in selected_columns:
        try:
            # list.index returns the first match, so a repeated name resolves to its earliest column
            index = schema_columns.index(column)
        except ValueError:
            # Only a missing name is translated; other errors (and KeyboardInterrupt) are no longer swallowed.
            raise ValueError("Invalid column name %s provided"
                             ", please choose from: (%s)" % (column, ",".join(schema_columns)))
        indices.append(index)

    return indices

def schema_to_scala(sc, python_schema):
    """
    Converts a python schema into its Scala-side representation using the JVM SchemaHelper.

    Each schema entry is sent to the JVM as a two-element list ``[column_name, type_string]``, where the
    type string comes from ``dtypes.to_string``.

    :param sc: context object passed to jvm_scala_schema to reach SchemaHelper
    :param python_schema: iterable of (column_name, data_type) tuples
    :return: the result of SchemaHelper.pythonToScala (presumably the Scala schema object)
    """
    # A list comprehension (instead of map + lambda) always yields a concrete list for the JVM bridge,
    # under Python 2 and Python 3 alike.
    list_of_list_of_str_schema = [[column[0], dtypes.to_string(column[1])] for column in python_schema]
    return jvm_scala_schema(sc).pythonToScala(list_of_list_of_str_schema)

def schema_to_python(sc, scala_schema):
    """
    Converts a Scala schema back into a python schema.

    SchemaHelper.scalaToPython yields (name, type_string) pairs; each type string is resolved with
    ``dtypes.get_from_string``.

    :param sc: context object passed to jvm_scala_schema to reach SchemaHelper
    :param scala_schema: the Scala-side schema to convert
    :return: list of (column_name, data_type) tuples
    """
    helper = jvm_scala_schema(sc)
    string_pairs = helper.scalaToPython(scala_schema)
    python_schema = []
    for column_name, type_name in string_pairs:
        python_schema.append((column_name, dtypes.get_from_string(type_name)))
    return python_schema


def validate(python_schema):
    """
    Raises a ValueError if the given schema is not a tuple or list of tuples of the form (column_name, data type),
    where column_name must be a string followed by a supported data type, or if a column name is repeated.

    A single tuple is accepted and treated as a one-column schema.

    :param python_schema: a (column_name, data_type) tuple, or a list of such tuples
    :raises ValueError: on a non-tuple entry, a tuple whose length is not 2, a non-string column name,
                        or duplicate column names
    """
    try:
        string_types = basestring  # Python 2
    except NameError:
        string_types = str  # Python 3 has no basestring
    if not isinstance(python_schema, list):
        python_schema = [python_schema]
    unique_names = set()
    for item in python_schema:
        if not isinstance(item, tuple):
            raise ValueError("schema expected to contain tuples, encountered type %s" % type(item))
        # Check the length before indexing, so an empty tuple yields a ValueError rather than an IndexError.
        if len(item) != 2:
            raise ValueError("schema tuples should have 2 items (column name and type), but found tuple with length: %s" %
                             len(item))
        if not isinstance(item[0], string_types):
            raise ValueError("first entry in schema tuple should be a string, received type %s: %s" %
                             (type(item[0]), str(item[0])))
        unique_names.add(item[0])
    if len(unique_names) != len(python_schema):
        # Remove one occurrence of each distinct name; whatever remains are the repeated names.
        names = [x[0] for x in python_schema]
        for u in unique_names:
            names.remove(u)
        raise ValueError("schema has duplicate column names: %s" % str(names))


def validate_is_mergeable(tc, *python_schema):
    """
    Raises an error if the column names in the given schemas conflict.

    Each positional argument is one schema. A value that is not a list (e.g. a single
    (column_name, data_type) tuple) is wrapped into a one-element list. Every schema is converted with
    schema_to_scala, and the whole collection is handed to SchemaHelper.validateIsMergeable on the JVM,
    which performs the actual check (presumably raising through the JVM bridge on a conflict -- confirm).

    :param tc: context exposing ``sc`` and ``jutils.convert.to_scala_list``
    :param python_schema: one or more python schemas to check against each other
    """
    def _as_list(schema):
        return schema if isinstance(schema, list) else [schema]

    scala_schema_list = [schema_to_scala(tc.sc, _as_list(schema)) for schema in python_schema]
    helper = jvm_scala_schema(tc.sc)
    helper.validateIsMergeable(tc.jutils.convert.to_scala_list(scala_schema_list))

Functions

def get_indices_for_selected_columns(schema, selected_columns)

def get_indices_for_selected_columns(schema, selected_columns):
    """
    Returns the positional index of each selected column within the schema.

    :param schema: sequence of schema entries; only the first element of each (the column name) is read
    :param selected_columns: iterable of column names; the result follows this order
    :return: list of int indices into schema, one per entry in selected_columns
    :raises ValueError: if any selected column name is not present in the schema
    """
    schema_columns = [col[0] for col in schema]
    indices = []
    for column in selected_columns:
        try:
            # list.index returns the first match, so a repeated name resolves to its earliest column
            index = schema_columns.index(column)
        except ValueError:
            # Only a missing name is translated; other errors (and KeyboardInterrupt) are no longer swallowed.
            raise ValueError("Invalid column name %s provided"
                             ", please choose from: (%s)" % (column, ",".join(schema_columns)))
        indices.append(index)

    return indices

def get_schema_for_columns(schema, selected_columns)

def get_schema_for_columns(schema, selected_columns):
    """
    Returns the schema entries for the selected columns, in the order they were selected.

    :param schema: sequence of schema entries whose first element is the column name
    :param selected_columns: iterable of column names to pick out of schema
    :return: list of the matching schema entries
    :raises ValueError: propagated from get_indices_for_selected_columns for an unknown column name
    """
    return [schema[position] for position in get_indices_for_selected_columns(schema, selected_columns)]

def jvm_scala_schema(sc)

def jvm_scala_schema(sc):
    """
    Returns the JVM-side SchemaHelper object, reached through ``sc._jvm``.

    :param sc: context object whose ``_jvm`` attribute exposes the JVM package tree
    :return: the ``org.trustedanalytics.sparktk.frame.SchemaHelper`` handle
    """
    jvm = sc._jvm
    return jvm.org.trustedanalytics.sparktk.frame.SchemaHelper

def schema_to_python(sc, scala_schema)

def schema_to_python(sc, scala_schema):
    """
    Converts a Scala schema back into a python schema.

    SchemaHelper.scalaToPython yields (name, type_string) pairs; each type string is resolved with
    ``dtypes.get_from_string``.

    :param sc: context object passed to jvm_scala_schema to reach SchemaHelper
    :param scala_schema: the Scala-side schema to convert
    :return: list of (column_name, data_type) tuples
    """
    helper = jvm_scala_schema(sc)
    string_pairs = helper.scalaToPython(scala_schema)
    python_schema = []
    for column_name, type_name in string_pairs:
        python_schema.append((column_name, dtypes.get_from_string(type_name)))
    return python_schema

def schema_to_scala(sc, python_schema)

def schema_to_scala(sc, python_schema):
    """
    Converts a python schema into its Scala-side representation using the JVM SchemaHelper.

    Each schema entry is sent to the JVM as a two-element list ``[column_name, type_string]``, where the
    type string comes from ``dtypes.to_string``.

    :param sc: context object passed to jvm_scala_schema to reach SchemaHelper
    :param python_schema: iterable of (column_name, data_type) tuples
    :return: the result of SchemaHelper.pythonToScala (presumably the Scala schema object)
    """
    # A list comprehension (instead of map + lambda) always yields a concrete list for the JVM bridge,
    # under Python 2 and Python 3 alike.
    list_of_list_of_str_schema = [[column[0], dtypes.to_string(column[1])] for column in python_schema]
    return jvm_scala_schema(sc).pythonToScala(list_of_list_of_str_schema)

def validate(python_schema)

Raises a ValueError if the given schema is not a tuple or a list of tuples of the form (column_name, data_type), where column_name must be a string and data_type a supported data type, or if any column name appears more than once.

def validate(python_schema):
    """
    Raises a ValueError if the given schema is not a tuple or list of tuples of the form (column_name, data type),
    where column_name must be a string followed by a supported data type, or if a column name is repeated.

    A single tuple is accepted and treated as a one-column schema.

    :param python_schema: a (column_name, data_type) tuple, or a list of such tuples
    :raises ValueError: on a non-tuple entry, a tuple whose length is not 2, a non-string column name,
                        or duplicate column names
    """
    try:
        string_types = basestring  # Python 2
    except NameError:
        string_types = str  # Python 3 has no basestring
    if not isinstance(python_schema, list):
        python_schema = [python_schema]
    unique_names = set()
    for item in python_schema:
        if not isinstance(item, tuple):
            raise ValueError("schema expected to contain tuples, encountered type %s" % type(item))
        # Check the length before indexing, so an empty tuple yields a ValueError rather than an IndexError.
        if len(item) != 2:
            raise ValueError("schema tuples should have 2 items (column name and type), but found tuple with length: %s" %
                             len(item))
        if not isinstance(item[0], string_types):
            raise ValueError("first entry in schema tuple should be a string, received type %s: %s" %
                             (type(item[0]), str(item[0])))
        unique_names.add(item[0])
    if len(unique_names) != len(python_schema):
        # Remove one occurrence of each distinct name; whatever remains are the repeated names.
        names = [x[0] for x in python_schema]
        for u in unique_names:
            names.remove(u)
        raise ValueError("schema has duplicate column names: %s" % str(names))

def validate_is_mergeable(tc, *python_schema)

Raises an error if the column names in the given schemas conflict, i.e. if the schemas cannot be merged.

def validate_is_mergeable(tc, *python_schema):
    """
    Raises an error if the column names in the given schemas conflict.

    Each positional argument is one schema. A value that is not a list (e.g. a single
    (column_name, data_type) tuple) is wrapped into a one-element list. Every schema is converted with
    schema_to_scala, and the whole collection is handed to SchemaHelper.validateIsMergeable on the JVM,
    which performs the actual check (presumably raising through the JVM bridge on a conflict -- confirm).

    :param tc: context exposing ``sc`` and ``jutils.convert.to_scala_list``
    :param python_schema: one or more python schemas to check against each other
    """
    def _as_list(schema):
        return schema if isinstance(schema, list) else [schema]

    scala_schema_list = [schema_to_scala(tc.sc, _as_list(schema)) for schema in python_schema]
    helper = jvm_scala_schema(tc.sc)
    helper.validateIsMergeable(tc.jutils.convert.to_scala_list(scala_schema_list))