spark-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From "Jean-Marc Montanier (Jira)" <j...@apache.org>
Subject [jira] [Created] (SPARK-30397) [pyspark] Writer applied to custom model changes type of keys' dict from int to str
Date Tue, 31 Dec 2019 09:00:00 GMT
Jean-Marc Montanier created SPARK-30397:
-------------------------------------------

             Summary: [pyspark] Writer applied to custom model changes type of keys' dict
from int to str
                 Key: SPARK-30397
                 URL: https://issues.apache.org/jira/browse/SPARK-30397
             Project: Spark
          Issue Type: Bug
          Components: PySpark
    Affects Versions: 2.4.4
            Reporter: Jean-Marc Montanier


Hello,

 

I have a custom model that I'm trying to persist. Within this custom model there is a python
dict mapping from int to int. When the model is saved (with write().save('path')), the keys
of the dict are modified from int to str.

 

You can find bellow a code to reproduce the issue:
{code:python}
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Jean-Marc Montanier
@date: 2019/12/31
"""

from pyspark.sql import SparkSession

from pyspark import keyword_only
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml import Estimator, Model
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import udf


spark = SparkSession \
    .builder \
    .appName("ImputeNormal") \
    .getOrCreate()


class CustomFit(Estimator,
                HasInputCol,
                HasOutputCol,
                DefaultParamsReadable,
                DefaultParamsWritable,
                ):
    @keyword_only
    def __init__(self, inputCol="inputCol", outputCol="outputCol"):
        super(CustomFit, self).__init__()

        self._setDefault(inputCol="inputCol", outputCol="outputCol")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol="inputCol", outputCol="outputCol"):
        """
        setParams(self, inputCol="inputCol", outputCol="outputCol")
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        return self

    def _fit(self, data):
        inputCol = self.getInputCol()
        outputCol = self.getOutputCol()

        categories = data.where(data[inputCol].isNotNull()) \
            .groupby(inputCol) \
            .count() \
            .orderBy("count", ascending=False) \
            .limit(2)
        categories = dict(categories.toPandas().set_index(inputCol)["count"])
        for cat in categories:
            categories[cat] = int(categories[cat])

        return CustomModel(categories=categories,
                           input_col=inputCol,
                           output_col=outputCol)


class CustomModel(Model,
                  DefaultParamsReadable,
                  DefaultParamsWritable):

    input_col = Param(Params._dummy(), "input_col", "Name of the input column")
    output_col = Param(Params._dummy(), "output_col", "Name of the output column")
    categories = Param(Params._dummy(), "categories", "Top categories")

    def __init__(self, categories: dict = None, input_col="input_col", output_col="output_col"):
        super(CustomModel, self).__init__()

        self._set(categories=categories, input_col=input_col, output_col=output_col)

    def get_output_col(self) -> str:
        """
        output_col getter
        :return:
        """
        return self.getOrDefault(self.output_col)

    def get_input_col(self) -> str:
        """
        input_col getter
        :return:
        """
        return self.getOrDefault(self.input_col)

    def get_categories(self):
        """
        categories getter
        :return:
        """
        return self.getOrDefault(self.categories)

    def _transform(self, data):
        input_col = self.get_input_col()
        output_col = self.get_output_col()
        categories = self.get_categories()

        def get_cat(val):
            if val is None:
                return -1
            if val not in categories:
                return -1
            return int(categories[val])

        get_cat_udf = udf(get_cat, IntegerType())

        df = data.withColumn(output_col,
                             get_cat_udf(input_col))

        return df


def test_without_write():
    fit_df = spark.createDataFrame([[10]] * 5 + [[11]] * 4 + [[12]] * 3 + [[None]] * 2, ['input'])
    custom_fit = CustomFit(inputCol='input', outputCol='output')
    pipeline = Pipeline(stages=[custom_fit])
    pipeline_model = pipeline.fit(fit_df)

    print("Categories: {}".format(pipeline_model.stages[0].get_categories()))

    transform_df = spark.createDataFrame([[10]] * 2 + [[11]] * 2 + [[12]] * 2 + [[None]] *
2, ['input'])
    test = pipeline_model.transform(transform_df)
    test.show()  # This output is the expected output


def test_with_write():
    fit_df = spark.createDataFrame([[10]] * 5 + [[11]] * 4 + [[12]] * 3 + [[None]] * 2, ['input'])
    custom_fit = CustomFit(inputCol='input', outputCol='output')
    pipeline = Pipeline(stages=[custom_fit])
    pipeline_model = pipeline.fit(fit_df)

    print("Categories: {}".format(pipeline_model.stages[0].get_categories()))

    pipeline_model.write().save('tmp')
    loaded_model = PipelineModel.load('tmp')
    # We can see that the type of the keys is know str instead of int
    print("Categories: {}".format(loaded_model.stages[0].get_categories()))

    transform_df = spark.createDataFrame([[10]] * 2 + [[11]] * 2 + [[12]] * 2 + [[None]] *
2, ['input'])
    test = loaded_model.transform(transform_df)
    test.show()  # We can see that the output does not match the expected output


if __name__ == "__main__":
    test_without_write()
    test_with_write()

{code}
 



--
This message was sent by Atlassian Jira
(v8.3.4#803005)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@spark.apache.org
For additional commands, e-mail: issues-help@spark.apache.org


Mime
View raw message