madlib-commits mailing list archives

From ri...@apache.org
Subject [madlib] branch master updated: K-NN: Add kd-tree method for approximate knn
Date Thu, 21 Feb 2019 00:43:13 GMT
This is an automated email from the ASF dual-hosted git repository.

riyer pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e601fb  K-NN: Add kd-tree method for approximate knn
5e601fb is described below

commit 5e601fbdb4c6423c148f8bdfead0a9988f31800d
Author: Orhan Kislal <okislal@apache.org>
AuthorDate: Wed Feb 20 16:33:46 2019 -0800

    K-NN: Add kd-tree method for approximate knn
    
    JIRA: MADLIB-1061
    
    This commit adds a kd-tree option to the 'knn' function. A kd-tree is
    used to reduce the search space when finding nearest neighbors. The
    method implemented here does not build the complete kd-tree; instead, it
    allows the user to specify a maximum depth for the binary tree.
    
    Additional changes:
    - Add function to clean madlib views
    - Move k-nn out of 'Early Stage Development'
    
    Closes #352
    
    Co-authored-by: Rahul Iyer <riyer@apache.org>
    Co-authored-by: Frank McQuillan <fmcquillan@pivotal.io>
---
 doc/design/design.tex                              |   1 +
 doc/design/figures/2d_kdtree.pdf                   | Bin 0 -> 10652 bytes
 doc/design/modules/knn.tex                         | 146 +++++++
 doc/literature.bib                                 |  11 +
 doc/mainpage.dox.in                                |   2 +-
 src/ports/postgres/modules/knn/knn.py_in           | 480 +++++++++++++++++----
 src/ports/postgres/modules/knn/knn.sql_in          | 249 +++++++++--
 src/ports/postgres/modules/knn/test/knn.sql_in     | 287 +++++++++---
 src/ports/postgres/modules/utilities/admin.py_in   |  22 +
 .../postgres/modules/utilities/utilities.py_in     |   1 -
 .../postgres/modules/utilities/utilities.sql_in    |   8 +
 11 files changed, 1033 insertions(+), 174 deletions(-)

diff --git a/doc/design/design.tex b/doc/design/design.tex
index e9ed7b8..6772f89 100644
--- a/doc/design/design.tex
+++ b/doc/design/design.tex
@@ -231,6 +231,7 @@
 \input{modules/SVM}
 \input{modules/graph}
 \input{modules/neural-network}
+\input{modules/knn}
 \printbibliography
 
 \end{document}
diff --git a/doc/design/figures/2d_kdtree.pdf b/doc/design/figures/2d_kdtree.pdf
new file mode 100644
index 0000000..062ae23
Binary files /dev/null and b/doc/design/figures/2d_kdtree.pdf differ
diff --git a/doc/design/modules/knn.tex b/doc/design/modules/knn.tex
new file mode 100644
index 0000000..71af411
--- /dev/null
+++ b/doc/design/modules/knn.tex
@@ -0,0 +1,146 @@
+% Licensed to the Apache Software Foundation (ASF) under one
+% or more contributor license agreements.  See the NOTICE file
+% distributed with this work for additional information
+% regarding copyright ownership.  The ASF licenses this file
+% to you under the Apache License, Version 2.0 (the
+% "License"); you may not use this file except in compliance
+% with the License.  You may obtain a copy of the License at
+
+%   http://www.apache.org/licenses/LICENSE-2.0
+
+% Unless required by applicable law or agreed to in writing,
+% software distributed under the License is distributed on an
+% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+% KIND, either express or implied.  See the License for the
+% specific language governing permissions and limitations
+% under the License.
+
+% !TEX root = ../design.tex
+
+
+\chapter[k Nearest Neighbors]{k Nearest Neighbors}
+
+\begin{moduleinfo}
+\item[Authors] \href{mailto:okislal@pivotal.io}{Orhan Kislal}
+
+\item[History]
+	\begin{modulehistory}
+		\item[v0.1] Initial version: knn and kd-tree.
+	\end{modulehistory}
+\end{moduleinfo}
+
+
+% Abstract. What is the problem we want to solve?
+\section{Introduction} % (fold)
+\label{sec:knn_introduction}
+
+\emph{Some notes and figures in this section are borrowed from \cite{medium_knn} and \cite{point_knn}}.
+
+K-nearest neighbors (KNN) is one of the most commonly used learning
+algorithms. The goal of KNN is to find the $k$ training data points
+closest to a given test point. These neighbors can then be used to predict
+labels via classification or regression.
+
+Unlike most learning techniques, KNN does not have a training phase. It does
+not build a model that generalizes the data; instead, the algorithm uses the
+whole training dataset (or a specific subset of it) at prediction time.
+
+When KNN is used for classification, the output is a class membership (a
+discrete value): an object is classified by a majority vote of its neighbors,
+being assigned to the class most common among its $k$ nearest neighbors.
+When it is used for regression, the output is a continuous value for the
+object: the average (or median) of the values of its $k$ nearest
+neighbors.
+
+\section{Implementation Details}
+
+The basic KNN implementation relies on a cross join between the training dataset and the test dataset.
+
+\begin{sql}
+	(SELECT test_id,
+            train_id,
+            fn_dist(train_col_name, test_col_name) AS dist,
+            label
+    FROM train_table, test_table) AS knn_sub
+\end{sql}
+
+Once we have the distance for every train--test pair, the algorithm keeps, for each test point, the $k$ training points with the smallest distances.
+
+\begin{sql}
+	SELECT * FROM (SELECT row_number() OVER
+        (PARTITION BY test_id ORDER BY dist) AS r,
+        test_id,
+        train_id,
+        label
+	FROM knn_sub) AS ranked
+	WHERE r <= k
+\end{sql}
+
+Finally, the prediction is completed based on the labels of the selected
+training points for each test point.
+
+\section{Enabling KD-tree}
+
+A major shortcoming of KNN is that it is computationally expensive. In
+addition, there is no training phase, which means that every single
+prediction has to compute the full table join. One way to improve
+performance is to reduce the search space for each test point. The kd-tree
+option trades some accuracy of the output for higher performance by
+reducing the neighbor search space.
+
+A kd-tree organizes points in a $k$-dimensional space (here $k$ is the number
+of dimensions, not the number of neighbors). It is constructed like a binary
+search tree in which each level of the tree uses a specific dimension for splits.
+
+
+\begin{figure}[h]
+	\centering
+	\includegraphics[width=0.9\textwidth]{figures/2d_kdtree.pdf}
+\caption{A 2D kd-tree of depth 3}
+\label{kdd:2d_kdtree}
+\end{figure}
+
+A kd-tree is constructed by finding the median value of the data in a
+particular dimension and splitting the data into two sections around this
+value. This process is repeated, up to the user-specified depth, to create
+smaller regions. Once the kd-tree is built, any test point can be mapped to
+its assigned region, and this partitioning is used to limit the search space
+for nearest neighbors.
+
+Once we have the kd-tree regions and their borders, we find the associated
+region for each test point. This gives us the first region to search for
+nearest neighbors. In addition, we allow the user to request multiple
+regions to search, which means we have to decide which additional regions to
+include in the search. We implemented a backtracking algorithm to find these
+regions. The core idea is to find the closest border for each test point and
+select the region on the other side of that border. Note that points that
+reside in the same region might have different secondary (or tertiary, etc.)
+regions. Consider the tree in Figure~\ref{kdd:2d_kdtree}. A test point at
+$\langle 5, 2 \rangle$ is in the same region as $\langle 3, 3.9 \rangle$.
+However, their closest borders and the associated secondary regions are quite
+different. Similarly, $\langle 3, 3.9 \rangle$ and $\langle 6, 3.9 \rangle$
+share the same closest border ($y=4$), yet the regions across that border
+differ. To make sure that we get the correct region, we use the following
+scheme. For a given point $P$, we find the closest border, $dim[i] = x$, and
+$P$'s position relative to it ($pos = -1$ if $P$ is on the lower side,
+$pos = +1$ if it is on the higher side). We construct a new point that has
+the same values as $P$ in every dimension except $i$, and set its value for
+$dim[i]$ to $x - pos \cdot \epsilon$, which places it just across the border.
+Finally, we use the existing kd-tree to find this new point's assigned region;
+this region is the expansion target for $P$. We repeat this process with the next closest border as requested by the user.
+
+The KNN algorithm does not change significantly with the addition of regions.
+Assuming that the training and test datasets have their region information
+stored in the tables, the only necessary change is ensuring that the table
+join uses these region ids to limit the search space.
+
+
+\begin{sql}
+	(SELECT test_id,
+            train_id,
+            fn_dist(train_col_name, test_col_name) AS dist,
+            label
+    FROM train_table, test_table
+    WHERE train_table.region_id = test_table.region_id
+    ) AS knn_sub
+\end{sql}
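To make the brute-force design notes in knn.tex above concrete, here is a minimal in-memory Python sketch. It is an illustration, not MADlib code: the names knn_brute_force, predict_class and predict_value are hypothetical, and squared_dist_norm2 only borrows the name of MADlib's default distance function. Sorting all distances per test point and slicing the first k plays the role of row_number() OVER (PARTITION BY test_id ORDER BY dist) filtered to r <= k.

```python
from collections import Counter
from statistics import mean


def squared_dist_norm2(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def knn_brute_force(train, test, k, fn_dist=squared_dist_norm2):
    """Find the k nearest training rows for every test row.

    train: list of (train_id, point, label); test: list of (test_id, point).
    Returns {test_id: [(dist, train_id, label), ...]} sorted by distance.
    """
    neighbors = {}
    for test_id, t_point in test:
        # Cross join: distance from this test point to every training point.
        dists = [(fn_dist(p_point, t_point), train_id, label)
                 for train_id, p_point, label in train]
        # Rank by distance (ties broken by train_id) and keep rank r <= k.
        neighbors[test_id] = sorted(dists)[:k]
    return neighbors


def predict_class(neighbors):
    """Classification: majority vote over the neighbors' labels."""
    return Counter(label for _, _, label in neighbors).most_common(1)[0][0]


def predict_value(neighbors):
    """Regression: plain average of the neighbors' values."""
    return mean(label for _, _, label in neighbors)
```

For example, with training points [0, 0] and [0, 1] labeled 'a' and [5, 5] and [6, 5] labeled 'b', a test point at [0.2, 0.3] with k=2 selects training ids 1 and 2, and predict_class returns 'a'. Vote ties are resolved by Counter.most_common (first label encountered), which is an arbitrary choice of this sketch.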
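The median-split construction and the border-probe lookup described in knn.tex can be sketched in memory as follows. This is a hedged illustration rather than MADlib's implementation: the functions build_kd_cutoffs, find_region and regions_to_search, the heap-style node layout, and the default eps are all assumptions. The split rule (dimension cycling with the level, values below the cutoff go left, values at or above it go right, median taken with percentile_disc(0.5) semantics) follows build_kd_tree in knn.py_in, and leaves are numbered in the same breadth-first order. We also assume that the candidate borders for a point are the split planes on its root-to-leaf path.

```python
import math


def percentile_disc_median(values):
    """Median with percentile_disc(0.5) semantics: the first sorted value whose
    cumulative fraction reaches 0.5 (None for an empty set, like SQL NULL)."""
    if not values:
        return None
    ordered = sorted(values)
    return ordered[math.ceil(0.5 * len(ordered)) - 1]


def build_kd_cutoffs(points, depth):
    """Split value of each of the 2**depth - 1 internal nodes, in heap order
    (node j has children 2j+1 and 2j+2). Level l splits on dimension
    l % n_dims; values < cutoff go left, values >= cutoff go right."""
    n_dims = len(points[0])
    cutoffs = []
    members = [list(points)]  # members[j]: training points that reach node j
    for node in range(2 ** depth - 1):
        level = (node + 1).bit_length() - 1
        dim = level % n_dims
        cutoff = percentile_disc_median([p[dim] for p in members[node]])
        cutoffs.append(cutoff)
        if cutoff is None:  # empty region: both children stay empty
            members += [[], []]
        else:
            members.append([p for p in members[node] if p[dim] < cutoff])
            members.append([p for p in members[node] if p[dim] >= cutoff])
    return cutoffs


def find_region(cutoffs, depth, point):
    """Leaf id in 0 .. 2**depth - 1 (breadth-first order), or None if the
    path hits an empty region, mirroring a CASE that matches no branch."""
    node = 0
    for level in range(depth):
        cutoff = cutoffs[node]
        if cutoff is None:
            return None
        dim = level % len(point)
        node = 2 * node + (1 if point[dim] < cutoff else 2)
    return node - (2 ** depth - 1)


def regions_to_search(cutoffs, depth, point, n_regions, eps=1e-9):
    """Own region first, then the regions across the closest borders."""
    own = find_region(cutoffs, depth, point)
    regions = [] if own is None else [own]
    borders = []  # (distance to border, dim, border value x, pos)
    node = 0
    for level in range(depth):
        cutoff = cutoffs[node]
        if cutoff is None:
            break
        dim = level % len(point)
        pos = -1 if point[dim] < cutoff else 1  # -1: lower side, +1: higher side
        borders.append((abs(point[dim] - cutoff), dim, cutoff, pos))
        node = 2 * node + (1 if pos < 0 else 2)
    for _, dim, x, pos in sorted(borders):  # closest border first
        if len(regions) >= n_regions:
            break
        probe = list(point)
        probe[dim] = x - pos * eps  # same point, nudged just across the border
        region = find_region(cutoffs, depth, probe)
        if region is not None and region not in regions:
            regions.append(region)
    return regions
```

With the eight points [1, 1], [2, 5], [3, 2], [4, 7], [6, 1], [7, 6], [8, 3], [9, 8] and depth 2, the cutoffs are x=4 at the root and y=2 and y=6 below it. A test point at [3.5, 1.0] lies in leaf 0; its closest border is x=4, so regions_to_search with n_regions=3 returns [0, 2, 1].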
diff --git a/doc/literature.bib b/doc/literature.bib
index 3c07260..5aa19ab 100644
--- a/doc/literature.bib
+++ b/doc/literature.bib
@@ -986,3 +986,14 @@ Applied Survival Analysis},
     Title = {{TRAINING RECURRENT NEURAL NETWORKS}},
     Author = {{Ilya Sutskever}}
 }
+
+@misc{medium_knn,
+    Url = {https://medium.com/@adi.bronshtein/a-quick-introduction-to-k-nearest-neighbors-algorithm-62214cea29c7},
+    Title = {{A quick introduction to k nearest neighbors algorithm}},
+}
+
+@misc{point_knn,
+    Url = {http://pointclouds.org/documentation/tutorials/kdtree_search.php},
+    Title = {{How to use a KdTree to search}},
+}
+
diff --git a/doc/mainpage.dox.in b/doc/mainpage.dox.in
index 5568da6..c8b308d 100644
--- a/doc/mainpage.dox.in
+++ b/doc/mainpage.dox.in
@@ -191,6 +191,7 @@ complete matrix stored as a distributed table.
 @details Methods to perform a variety of supervised learning tasks.
 @{
     @defgroup grp_crf Conditional Random Field
+    @defgroup grp_knn k-Nearest Neighbors
     @defgroup grp_nn Neural Network
     @defgroup grp_regml Regression Models
     @brief A collection of methods for modeling conditional expectation of a response variable.
@@ -291,7 +292,6 @@ Interface and implementation are subject to change.
     @{
         @defgroup grp_minibatch_preprocessing_dl Mini-Batch Preprocessor for Deep Learning
     @}
-    @defgroup grp_knn k-Nearest Neighbors
     @defgroup grp_bayes Naive Bayes Classification
     @defgroup grp_sample Random Sampling
 @}
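The knn.py_in changes that follow pick extra leaves differently from the border-probe scheme in knn.tex: build_kd_tree stores each region's mean point in a table with a `_centers` postfix, and when leaf_nodes is greater than 1 the extended test view ranks every region center by fn_dist to the test point and keeps the nearest leaf_nodes regions. Below is a minimal in-memory Python sketch of that selection, plus the inverse-distance weight computed in _create_interim_tbl. The function names are hypothetical; only WEIGHT_FOR_ZERO_DIST and its value 1e6 come from the source, and how the weights are later aggregated is not shown here.

```python
WEIGHT_FOR_ZERO_DIST = 1e6  # same name and value as the constant in knn.py_in


def squared_dist_norm2(a, b):
    """Squared Euclidean distance (stand-in for the user's fn_dist)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def region_centers(points, regions):
    """Mean point per region id, like the table with the _centers postfix."""
    sums, counts = {}, {}
    for point, region in zip(points, regions):
        acc = sums.setdefault(region, [0.0] * len(point))
        for i, value in enumerate(point):
            acc[i] += value
        counts[region] = counts.get(region, 0) + 1
    return {r: [v / counts[r] for v in acc] for r, acc in sums.items()}


def leaves_to_explore(test_point, centers, leaf_nodes, fn_dist=squared_dist_norm2):
    """Region ids of the leaf_nodes centers closest to test_point, nearest first
    (the row_number() ... WHERE r <= leaf_nodes step; ties broken by region id)."""
    ranked = sorted(centers, key=lambda r: (fn_dist(test_point, centers[r]), r))
    return ranked[:leaf_nodes]


def inverse_distance_weight(dist):
    """Neighbor weight: 1/dist, or a large fixed weight for an exact match."""
    return WEIGHT_FOR_ZERO_DIST if dist == 0.0 else 1.0 / dist
```

As far as the SQL shown goes, every center is ranked, so a test point's own leaf is selected only if its center is among the nearest ones; the sketch reproduces that behavior rather than forcing the own leaf in.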
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index 4db7ac1..bf64352 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -27,27 +27,38 @@
 """
 
 import plpy
-from utilities.validate_args import input_tbl_valid, output_tbl_valid
-from utilities.validate_args import cols_in_tbl_valid
-from utilities.validate_args import is_col_array
-from utilities.validate_args import array_col_has_no_null
-from utilities.validate_args import get_expr_type
+import copy
+from collections import defaultdict
+from math import log
+from utilities.control import MinWarning
+from utilities.utilities import INTEGER
 from utilities.utilities import _assert
+from utilities.utilities import add_postfix
+from utilities.utilities import extract_keyvalue_params
+from utilities.utilities import py_list_to_sql_string
 from utilities.utilities import unique_string
-from utilities.control import MinWarning
-from utilities.validate_args import quote_ident
-from utilities.validate_args import is_var_valid
 from utilities.utilities import NUMERIC, ONLY_ARRAY
 from utilities.utilities import is_valid_psql_type
 from utilities.utilities import is_pg_major_version_less_than
+from utilities.utilities import num_features
+from utilities.validate_args import array_col_has_no_null
+from utilities.validate_args import cols_in_tbl_valid
+from utilities.validate_args import drop_tables
+from utilities.validate_args import get_cols
+from utilities.validate_args import get_expr_type
+from utilities.validate_args import input_tbl_valid, output_tbl_valid
+from utilities.validate_args import is_col_array
+from utilities.validate_args import is_var_valid
+from utilities.validate_args import quote_ident
 
-MAX_WEIGHT_ZERO_DIST = 1e6
-
+WEIGHT_FOR_ZERO_DIST = 1e6
+BRUTE_FORCE = 'brute_force'
+KD_TREE = 'kd_tree'
 
 def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
                      label_column_name, test_source, test_column_name,
                      test_id, output_table, k, output_neighbors, fn_dist,
-                     **kwargs):
+                     is_brute_force, depth, leaf_nodes, **kwargs):
     input_tbl_valid(point_source, 'kNN')
     input_tbl_valid(test_source, 'kNN')
     output_tbl_valid(output_table, 'kNN')
@@ -60,14 +71,16 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
         cols_in_tbl_valid(point_source, [label_column_name], 'kNN')
 
     _assert(is_var_valid(point_source, point_column_name),
-            "kNN error: {0} is an invalid column name or expression for point_column_name param".format(point_column_name))
+            "kNN error: {0} is an invalid column name or "
+            "expression for point_column_name param".format(point_column_name))
     point_col_type = get_expr_type(point_column_name, point_source)
     _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
             "kNN Error: Feature column or expression '{0}' in train table is not"
             " an array.".format(point_column_name))
 
     _assert(is_var_valid(test_source, test_column_name),
-            "kNN error: {0} is an invalid column name or expression for test_column_name param".format(test_column_name))
+            "kNN error: {0} is an invalid column name or expression for "
+            "test_column_name param".format(test_column_name))
     test_col_type = get_expr_type(test_column_name, test_source)
     _assert(is_valid_psql_type(test_col_type, NUMERIC | ONLY_ARRAY),
             "kNN Error: Feature column or expression '{0}' in test table is not"
@@ -101,7 +114,7 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
                        format(col_type, point_source))
 
     col_type_test = get_expr_type(test_id, test_source).lower()
-    if col_type_test not in ['integer']:
+    if col_type_test not in INTEGER:
         plpy.error("kNN Error: Invalid data type '{0}' for"
                    " test_id column in table '{1}'.".
                    format(col_type_test, test_source))
@@ -113,27 +126,325 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
                                'squared_dist_norm2', 'dist_angle',
                                'dist_tanimoto')])
 
-        profunc = ("proisagg = TRUE" if is_pg_major_version_less_than(schema_madlib, 11)
-              else "prokind = 'a'")
+        profunc = ("proisagg = TRUE"
+                   if is_pg_major_version_less_than(schema_madlib, 11)
+                   else "prokind = 'a'")
 
         is_invalid_func = plpy.execute("""
-            SELECT prorettype != 'DOUBLE PRECISION'::regtype OR
-                   {profunc} AS OUTPUT
+            SELECT prorettype != 'DOUBLE PRECISION'::regtype OR {profunc} AS OUTPUT
             FROM pg_proc
             WHERE oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
             """.format(fn_dist=fn_dist, profunc=profunc))[0]['output']
 
         if is_invalid_func or (fn_dist not in dist_functions):
-            plpy.error("KNN error: Distance function has invalid signature "
-                       "or is not a simple function.")
-
+            plpy.error("KNN error: Distance function ({0}) has invalid signature "
+                       "or is not a simple function.".format(fn_dist))
+    if not is_brute_force:
+        if depth <= 0:
+            plpy.error("kNN Error: depth={0} is an invalid value, must be "
+                       "greater than 0.".format(depth))
+        if leaf_nodes <= 0:
+            plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be "
+                       "greater than 0.".format(leaf_nodes))
+        if pow(2, depth) <= leaf_nodes:
+            plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. "
+                       "The leaf_nodes value must be lower than 2^depth".
+                       format(depth, leaf_nodes))
     return k
 # ------------------------------------------------------------------------------
 
 
+def build_kd_tree(schema_madlib, source_table, output_table, point_column_name,
+                  depth, r_id, **kwargs):
+    """
+        KD-tree function to create a partitioning for KNN
+        Args:
+            @param schema_madlib        Name of the Madlib Schema
+            @param source_table         Training data table
+            @param output_table         Name of the table to store kd tree
+            @param point_column_name    Name of the column with training data
+                                        or expression that evaluates to a
+                                        numeric array
+            @param depth                Depth of the kd tree
+            @param r_id                 Name of the region id column
+    """
+    with MinWarning("error"):
+
+        validate_kd_tree(source_table, output_table, point_column_name, depth)
+        n_features = num_features(source_table, point_column_name)
+
+        clauses = [' 1=1 ']
+        centers_table = add_postfix(output_table, "_centers")
+        clause_counter = 0
+        for curr_level in range(depth):
+            curr_feature = (curr_level % n_features) + 1
+            for curr_leaf in range(pow(2,curr_level)):
+                clause = clauses[clause_counter]
+                cutoff_sql = """
+                    SELECT percentile_disc(0.5)
+                           WITHIN GROUP (
+                            ORDER BY ({point_column_name})[{curr_feature}]
+                           ) AS cutoff
+                    FROM {source_table}
+                    WHERE {clause}
+                    """.format(**locals())
+
+                cutoff = plpy.execute(cutoff_sql)[0]['cutoff']
+                cutoff = "NULL" if cutoff is None else cutoff
+                clause_counter += 1
+                clauses.append(clause +
+                               "AND ({point_column_name})[{curr_feature}] < {cutoff} ".
+                               format(**locals()))
+                clauses.append(clause +
+                               "AND ({point_column_name})[{curr_feature}] >= {cutoff} ".
+                               format(**locals()))
+
+        n_leaves = pow(2, depth)
+        case_when_clause = '\n'.join(["WHEN {0} THEN {1}::INTEGER".format(cond, i)
+                                     for i, cond in enumerate(clauses[-n_leaves:])])
+        output_sql = """
+            CREATE TABLE {output_table} AS
+                SELECT *,
+                       CASE {case_when_clause} END AS {r_id}
+                FROM {source_table}
+            """.format(**locals())
+        plpy.execute(output_sql)
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(centers_table))
+        centers_sql = """
+            CREATE TABLE {centers_table} AS
+                SELECT {r_id}, {schema_madlib}.array_scalar_mult(
+                        {schema_madlib}.sum({point_column_name})::DOUBLE PRECISION[],
+                        (1.0/count(*))::DOUBLE PRECISION) AS __center__
+                FROM {output_table}
+                GROUP BY {r_id}
+            """.format(**locals())
+        plpy.execute(centers_sql)
+        return case_when_clause
+# ------------------------------------------------------------------------------
+
+
+def validate_kd_tree(source_table, output_table, point_column_name, depth):
+
+    input_tbl_valid(source_table, 'kd_tree')
+    output_tbl_valid(output_table, 'kd_tree')
+    output_tbl_valid(output_table+"_centers", 'kd_tree')
+
+    _assert(is_var_valid(source_table, point_column_name),
+            "kd_tree error: {0} is an invalid column name or expression for "
+            "point_column_name param".format(point_column_name))
+    point_col_type = get_expr_type(point_column_name, source_table)
+    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
+            "kNN Error: Feature column or expression '{0}' in train table is not"
+            " an array.".format(point_column_name))
+    if depth <= 0:
+        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+                   "than 0.".format(depth))
+# ------------------------------------------------------------------------------
+
+
+def knn_kd_tree(schema_madlib, kd_out, test_source, test_column_name, test_id,
+                fn_dist, max_leaves_to_explore, depth, r_id, case_when_clause,
+                t_col_name, **kwargs):
+    """
+        KNN function to find the K Nearest neighbours using kd tree
+        Args:
+            @param schema_madlib        Name of the Madlib Schema
+            @param kd_out               Name of the kd tree table
+            @param test_source          Name of the table containing the test
+                                        data points.
+            @param test_column_name     Name of the column with testing data
+                                        points or expression that evaluates to a
+                                        numeric array
+            @param test_id              Name of the column having ids of data
+                                        points in test data table.
+            @param fn_dist              Distance metrics function.
+            @param max_leaves_to_explore Number of leaf nodes to explore
+            @param depth                Depth of the kd tree
+            @param r_id                 Name of the region id column
+            @param case_when_clause     SQL string for reconstructing the
+                                        kd-tree
+            @param t_col_name           Unique test point column name
+    """
+    with MinWarning("error"):
+        centers_table = add_postfix(kd_out, "_centers")
+        test_view = add_postfix(kd_out, "_test_view")
+
+        n_leaves = pow(2,depth)
+        plpy.execute("DROP VIEW IF EXISTS {test_view}".format(**locals()))
+        test_view_sql = """
+            CREATE VIEW {test_view} AS
+                SELECT {test_id},
+                       ({test_column_name})::DOUBLE PRECISION[] AS {t_col_name},
+                       CASE
+                        {case_when_clause}
+                       END AS {r_id}
+                FROM {test_source}""".format(**locals())
+        plpy.execute(test_view_sql)
+
+        if max_leaves_to_explore > 1:
+            ext_test_view = add_postfix(kd_out, "_ext_test_view")
+            ext_test_view_sql = """
+                CREATE VIEW {ext_test_view} AS
+                SELECT * FROM(
+                    SELECT
+                        row_number() OVER (PARTITION BY {test_id}
+                                           ORDER BY __dist_center__) AS r,
+                        {test_id},
+                        {t_col_name},
+                        {r_id}
+                    FROM (
+                        SELECT
+                            {test_id},
+                            {t_col_name},
+                            {centers_table}.{r_id} AS {r_id},
+                            {fn_dist}({t_col_name}, __center__) AS __dist_center__
+                        FROM {test_view}, {centers_table}
+                    ) q1
+                ) q2
+                WHERE r <= {max_leaves_to_explore}
+            """.format(**locals())
+            plpy.execute(ext_test_view_sql)
+        else:
+            ext_test_view = test_view
+
+        return ext_test_view
+# ------------------------------------------------------------------------------
+
+def _create_interim_tbl(schema_madlib, point_source, point_column_name, point_id,
+    label_name, test_source, test_column_name, test_id, interim_table, k,
+    fn_dist, test_id_temp, train_id, dist_inverse, comma_label_out_alias,
+    label_out, r_id, kd_out, train, t_col_name, **kwargs):
+    """
+        KNN function to create the interim table
+        Args:
+            @param schema_madlib        Name of the Madlib Schema
+            @param point_source         Training data table
+            @param point_column_name    Name of the column with training data
+                                        or expression that evaluates to a
+                                        numeric array
+            @param point_id             Name of the column having ids of data
+                                        points in the train data
+                                        table.
+            @param label_name           Name of the column with labels/values
+                                        of training data points.
+            @param test_source          Name of the table containing the test
+                                        data points.
+            @param test_column_name     Name of the column with testing data
+                                        points or expression that evaluates to a
+                                        numeric array
+            @param test_id              Name of the column having ids of data
+                                        points in test data table.
+            @param interim_table        Name of the table to store interim
+                                        results.
+            @param k                    default: 1. Number of nearest
+                                        neighbors to consider
+            @param fn_dist              Distance metrics function. Default is
+                                        squared_dist_norm2. Following functions
+                                        are supported:
+                                        dist_norm1, dist_norm2, squared_dist_norm2,
+                                        dist_angle, dist_tanimoto
+                                        Or user defined function with signature
+                                        DOUBLE PRECISION[] x, DOUBLE PRECISION[] y
+                                        -> DOUBLE PRECISION
+            Following parameters are passed to ensure the interim table has
+            identical features in both implementations
+            @param test_id_temp
+            @param train_id
+            @param dist_inverse
+            @param comma_label_out_alias
+            @param label_out
+            @param r_id
+            @param kd_out
+            @param train
+            @param t_col_name
+    """
+    with MinWarning("error"):
+        # If r_id is None, we are using the brute force algorithm.
+        is_brute_force = not bool(r_id)
+        r_id = "NULL AS {0}".format(unique_string()) if not r_id else r_id
+
+        p_col_name = unique_string(desp='p_col_name')
+        x_temp_table = unique_string(desp='x_temp_table')
+        y_temp_table = unique_string(desp='y_temp_table')
+        test = unique_string(desp='test')
+        r = unique_string(desp='r')
+        dist = unique_string(desp='dist')
+
+        if not is_brute_force:
+            point_source = kd_out
+            where_condition = "{train}.{r_id} = {test}.{r_id} ".format(**locals())
+            select_sql = """ {train}.{r_id} AS tr_{r_id},
+                            {test}.{r_id} AS test_{r_id}, """.format(**locals())
+            t_col_cast = t_col_name
+        else:
+            where_condition = "1 = 1"
+            select_sql = ""
+            t_col_cast = "({test_column_name}) AS {t_col_name}".format(**locals())
+
+        plpy.execute("""
+            CREATE TABLE {interim_table} AS
+                SELECT *
+                FROM (
+                    SELECT row_number() OVER
+                                (PARTITION BY {test_id_temp} ORDER BY {dist}) AS {r},
+                           {test_id_temp},
+                           {train_id},
+                           CASE WHEN {dist} = 0.0 THEN {weight_for_zero_dist}
+                                ELSE 1.0 / {dist}
+                           END AS {dist_inverse}
+                           {comma_label_out_alias}
+                    FROM (
+                        SELECT {select_sql}
+                               {test}.{test_id} AS {test_id_temp},
+                               {train}.{point_id} AS {train_id},
+                               {fn_dist}({p_col_name}, {t_col_name}) AS {dist}
+                               {label_out}
+                        FROM
+                            (
+                                SELECT {point_id},
+                                       {r_id},
+                                       {point_column_name} AS {p_col_name}
+                                       {label_name}
+                                FROM {point_source}
+                            ) {train},
+                            (
+                                SELECT {test_id},
+                                       {t_col_cast},
+                                       {r_id}
+                                FROM {test_source}
+                            ) {test}
+                        WHERE
+                            {where_condition}
+                    ) {x_temp_table}
+                ) {y_temp_table}
+            WHERE {y_temp_table}.{r} <= {k}
+            """.format(weight_for_zero_dist=WEIGHT_FOR_ZERO_DIST, **locals()))
+
+# ------------------------------------------------------------------------------
+
+def _get_algorithm_name(algorithm):
+    if not algorithm:
+        algorithm = BRUTE_FORCE
+    else:
+        supported_algorithms = [BRUTE_FORCE, KD_TREE]
+        try:
+            # allow user to specify a prefix substring of
+            # supported algorithms. This works because the supported
+            # algorithms have unique prefixes.
+            algorithm = next(x for x in supported_algorithms
+                               if x.startswith(algorithm))
+        except StopIteration:
+            # next() raises StopIteration if no element is found
+            plpy.error("kNN Error: Invalid algorithm: "
+                       "{0}. Supported algorithms are ({1})"
+                       .format(algorithm, ','.join(sorted(supported_algorithms))))
+    return algorithm
+# ------------------------------------------------------------------------------
+
 def knn(schema_madlib, point_source, point_column_name, point_id,
         label_column_name, test_source, test_column_name, test_id, output_table,
-        k, output_neighbors, fn_dist, weighted_avg, **kwargs):
+        k, output_neighbors, fn_dist, weighted_avg, algorithm, algorithm_params,
+        **kwargs):
     """
         KNN function to find the K Nearest neighbours
         Args:
@@ -158,7 +469,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                                         results.
             @param k                    default: 1. Number of nearest
                                         neighbors to consider
-            @output_neighbours          Outputs the list of k-nearest neighbors
+            @param output_neighbours    Outputs the list of k-nearest neighbors
                                         that were used in the voting/averaging.
             @param fn_dist              Distance metrics function. Default is
                                         squared_dist_norm2. Following functions
@@ -166,32 +477,52 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                                         dist_norm1 , dist_norm2,squared_dist_norm2,
                                         dist_angle , dist_tanimoto
                                         Or user defined function with signature
-                                        DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION
-            @param weighted_avg         Calculates the Regression or classication of k-NN using
+                                        DOUBLE PRECISION[] x, DOUBLE PRECISION[] y
+                                        -> DOUBLE PRECISION
+            @param weighted_avg         If True, computes the k-NN regression or
+                                        classification result using
                                         the weighted average method.
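+            @param algorithm            Algorithm used to find the nearest
+                                        neighbors: 'brute_force' (exact,
+                                        default) or 'kd_tree' (approximate)
+            @param algorithm_params     Parameters for the kd-tree algorithm,
+                                        e.g. 'depth=3, leaf_nodes=2'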
     """
     with MinWarning('warning'):
         output_neighbors = True if output_neighbors is None else output_neighbors
         if k is None:
             k = 1
+
+        algorithm = _get_algorithm_name(algorithm)
+
+        # Default values for depth and leaf nodes
+        depth = 3
+        max_leaves_to_explore = 2
+
+        if algorithm_params:
+            params_types = {'depth': int, 'leaf_nodes': int}
+            default_args = {'depth': 3, 'leaf_nodes': 2}
+            algorithm_params_dict = extract_keyvalue_params(algorithm_params,
+                                                            params_types,
+                                                            default_args)
+
+            depth = algorithm_params_dict['depth']
+            max_leaves_to_explore = algorithm_params_dict['leaf_nodes']
+
         knn_validate_src(schema_madlib, point_source,
                          point_column_name, point_id, label_column_name,
                          test_source, test_column_name, test_id,
-                         output_table, k, output_neighbors, fn_dist)
+                         output_table, k, output_neighbors, fn_dist,
+                         algorithm == BRUTE_FORCE, depth, max_leaves_to_explore)
+
+        n_features = num_features(test_source, test_column_name)
 
         # Unique Strings
-        x_temp_table = unique_string(desp='x_temp_table')
-        y_temp_table = unique_string(desp='y_temp_table')
         label_col_temp = unique_string(desp='label_col_temp')
         test_id_temp = unique_string(desp='test_id_temp')
+
         train = unique_string(desp='train')
-        test = unique_string(desp='test')
-        p_col_name = unique_string(desp='p_col_name')
-        t_col_name = unique_string(desp='t_col_name')
-        dist = unique_string(desp='dist')
         train_id = unique_string(desp='train_id')
         dist_inverse = unique_string(desp='dist_inverse')
-        r = unique_string(desp='r')
+        dim = unique_string(desp='dim')
+        t_col_name = unique_string(desp='t_col_name')
 
         if not fn_dist:
             fn_dist = '{0}.squared_dist_norm2'.format(schema_madlib)
@@ -206,6 +537,9 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
         view_def = ""
         view_join = ""
         view_grp_by = ""
+        r_id = None
+        kd_output_table = None
+        test_data = None
 
         if label_column_name:
             label_column_type = get_expr_type(
@@ -240,11 +574,10 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                                     {test_id_temp},
                                     {label_col_temp},
                                     sum({dist_inverse}) data_sum
-                                FROM pg_temp.{interim_table}
+                                FROM {interim_table}
                                 GROUP BY {test_id_temp},
                                          {label_col_temp}
                             ) a
-                            -- GROUP BY {test_id_temp} , {label_col_temp}
                         )
                         """.format(**locals())
                     # This join is needed to get the max value of predicion
@@ -276,44 +609,34 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
             comma_label_out_alias = ""
             label_name = ""
 
-        # interim_table picks the 'k' nearest neighbors for each test point
         if output_neighbors:
             knn_neighbors = (", array_agg(knn_temp.{train_id} ORDER BY "
                              "knn_temp.{dist_inverse} DESC) AS k_nearest_neighbours ").format(**locals())
         else:
             knn_neighbors = ''
-        plpy.execute("""
-            CREATE TEMP TABLE {interim_table} AS
-                SELECT * FROM (
-                    SELECT row_number() over
-                            (partition by {test_id_temp} order by {dist}) AS {r},
-                            {test_id_temp},
-                            {train_id},
-                            CASE WHEN {dist} = 0.0 THEN {max_weight_zero_dist}
-                                 ELSE 1.0 / {dist}
-                            END AS {dist_inverse}
-                            {comma_label_out_alias}
-                    FROM (
-                        SELECT {test}.{test_id} AS {test_id_temp},
-                            {train}.{point_id} as {train_id},
-                            {fn_dist}(
-                                {p_col_name},
-                                {t_col_name})
-                            AS {dist}
-                            {label_out}
-                            FROM
-                            (
-                            SELECT {point_id} , {point_column_name} as {p_col_name} {label_name} from {point_source}
-                            ) {train},
-                            (
-                            SELECT {test_id} ,{test_column_name} as {t_col_name} from {test_source}
-                            ) {test}
-                        ) {x_temp_table}
-                    ) {y_temp_table}
-            WHERE {y_temp_table}.{r} <= {k}
-            """.format(max_weight_zero_dist=MAX_WEIGHT_ZERO_DIST, **locals()))
 
-        sql = """
+        if 'kd_tree' in algorithm:
+            r_id = unique_string(desp='r_id')
+            kd_output_table = unique_string(desp='kd_tree')
+            case_when_clause = build_kd_tree(schema_madlib,
+                                             point_source,
+                                             kd_output_table,
+                                             point_column_name,
+                                             depth, r_id)
+            test_data = knn_kd_tree(schema_madlib, kd_output_table, test_source,
+                                    test_column_name, test_id, fn_dist,
+                                    max_leaves_to_explore, depth, r_id,
+                                    case_when_clause, t_col_name)
+        else:
+            test_data = test_source
+
+        # interim_table picks the 'k' nearest neighbors for each test point
+        _create_interim_tbl(schema_madlib, point_source, point_column_name,
+                            point_id, label_name, test_data, test_column_name,
+                            test_id, interim_table, k, fn_dist, test_id_temp,
+                            train_id, dist_inverse, comma_label_out_alias,
+                            label_out, r_id, kd_output_table, train, t_col_name)
+        output_sql = """
             CREATE TABLE {output_table} AS
                 {view_def}
                 SELECT
@@ -322,7 +645,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                     {pred_out}
                     {knn_neighbors}
                 FROM
-                    pg_temp.{interim_table}  AS knn_temp
+                    {interim_table}  AS knn_temp
                     JOIN
                     {test_source} AS knn_test
                 ON knn_temp.{test_id_temp} = knn_test.{test_id}
@@ -330,11 +653,19 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                 GROUP BY knn_temp.{test_id_temp},
                     {test_column_name}
                          {view_grp_by}
-            """
-        plpy.execute(sql.format(**locals()))
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
+            """.format(**locals())
+        plpy.execute(output_sql)
+        drop_tables([interim_table])
+
+        if 'kd_tree' in algorithm:
+            centers_table = add_postfix(kd_output_table, "_centers")
+            test_view = add_postfix(kd_output_table, "_test_view")
+            ext_test_view = add_postfix(kd_output_table, "_ext_test_view")
+            plpy.execute("DROP VIEW IF EXISTS {0} CASCADE".format(test_view))
+            plpy.execute("DROP VIEW IF EXISTS {0} CASCADE".format(ext_test_view))
+            drop_tables([centers_table, kd_output_table])
         return
-
+# ------------------------------------------------------------------------------
 
 def knn_help(schema_madlib, message, **kwargs):
     """
@@ -366,7 +697,9 @@ SELECT {schema_madlib}.knn(
     k,                  -- value of k. Default will go as 1
     output_neighbors    -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
     fn_dist             -- The name of the function to use to calculate the distance from a data point to a centroid.
-    weighted_avg         Calculates the Regression or classication of k-NN using the weighted average method.
+    weighted_avg        -- Calculates the regression or classification of k-NN using the weighted average method.
+    algorithm           -- The algorithm to use for knn: 'brute_force' (default) or 'kd_tree'.
+    algorithm_params    -- The parameters for the kd-tree algorithm, e.g. 'depth=3, leaf_nodes=2'.
     );
 
 -----------------------------------------------------------------------
@@ -397,6 +730,5 @@ of k nearest neighbors of the given testing example.
 For an overview on usage, run:
 SELECT {schema_madlib}.knn('usage');
 """
-
     return help_string.format(schema_madlib=schema_madlib)
 # ------------------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index 6fe1672..0693e94 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -32,7 +32,6 @@
 
 m4_include(`SQLCommon.m4')
 
-
 /**
 @addtogroup grp_knn
 
@@ -47,24 +46,22 @@ m4_include(`SQLCommon.m4')
 </ul>
 </div>
 
-@brief Finds k nearest data points to the given data point and outputs majority vote value of output classes for classification, and average value of target values for regression.
-
-\warning <em> This MADlib method is still in early stage development. There may be some
-issues that will be addressed in a future version. Interface and implementation
-are subject to change. </em>
+@brief Finds \f$k\f$ nearest data points to the given data point and outputs majority
+vote value of output classes for classification, or average value of target
+values for regression.
 
 @anchor knn
 
-K-nearest neighbors is a method for finding the k closest points to a
-given data point in terms of a given metric. Its input consists of
-data points as features from testing examples, and it
-looks for k closest points in the training set for each of the data
-points in test set.  The output of KNN depends on the type of task.
-For classification, the output is the majority vote of the classes of
-the k nearest data points. That is, the testing example gets assigned the
-most popular class from the nearest neighbors.
-For regression, the output is the average of the values of k nearest
-neighbors of the given test point.
+K-nearest neighbors is a method for finding the \f$k\f$ closest points to a given data
+point in terms of a given metric. Its input consists of test points, each a vector
+of features, and for each test point it finds the \f$k\f$ closest points in the
+training set. The output of KNN depends on the type of task. For classification,
+the output is the majority vote of the classes of the \f$k\f$ nearest data points.
+For regression, the output is the average of the target values of the \f$k\f$
+nearest neighbors of the given test point.
+
+Both an exact method (brute force) and an approximate method (kd-tree) are
+supported. The approximate method can be used when the run time of the exact
+method is too long.
 
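The following is a minimal in-memory sketch of the exact (brute-force) method. It does not show how MADlib computes the result, because MADlib works in SQL over tables. The helper names and the toy training data are hypothetical. `squared_dist_norm2` computes the squared Euclidean distance, the same quantity as the default `fn_dist`.

```python
from collections import Counter


def squared_dist_norm2(x, y):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def k_nearest(train, query, k, dist=squared_dist_norm2):
    """Exact search: rank every training row by its distance to `query`.

    train is a list of (id, point, label) tuples; returns the k closest
    rows, nearest first.
    """
    return sorted(train, key=lambda row: dist(row[1], query))[:k]


def classify(train, query, k):
    """Majority vote over the labels of the k nearest neighbors."""
    labels = [label for _, _, label in k_nearest(train, query, k)]
    return Counter(labels).most_common(1)[0][0]


def regress(train, query, k):
    """Average of the target values of the k nearest neighbors."""
    values = [label for _, _, label in k_nearest(train, query, k)]
    return sum(values) / len(values)


# Hypothetical toy data: (id, point, label)
train = [(1, (0, 0), 1), (2, (1, 0), 1), (3, (0, 1), 0),
         (4, (5, 5), 0), (5, (6, 5), 0)]
print([row[0] for row in k_nearest(train, (0.2, 0.1), 3)])  # [1, 2, 3]
print(classify(train, (0.2, 0.1), 3))                       # 1
print(classify(train, (5.5, 5.0), 3))                       # 0
```

As with `output_neighbors`, the neighbor list is ordered from nearest to farthest. In this sketch, a tie in the vote goes to the label encountered first, which is the nearer neighbor's label. That rule is not taken from MADlib.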
 @anchor usage
 @par Usage
@@ -80,7 +77,9 @@ knn( point_source,
      k,
      output_neighbors,
      fn_dist,
-     weighted_avg
+     weighted_avg,
+     algorithm,
+     algorithm_params
    )
 </pre>
 
@@ -93,7 +92,7 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
 </dd>
 
 <dt>point_column_name</dt>
-<dd>TEXT. Name of the column with training data points 
+<dd>TEXT. Name of the column with training data points
 or expression that evaluates to a numeric array</dd>
 
 <dt>point_id</dt>
@@ -116,7 +115,7 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
 </dd>
 
 <dt>test_column_name</dt>
-<dd>TEXT. Name of the column with testing data points 
+<dd>TEXT. Name of the column with testing data points
 or expression that evaluates to a numeric array</dd>
 
 <dt>test_id</dt>
@@ -154,13 +153,38 @@ regression values using a weighted average.   The idea is to
 weigh the contribution of each of the k neighbors according
 to their distance to the test point, giving greater influence
 to closer neighbors.  The distance function 'fn_dist' specified
-above is used.
+above is used.  For classification, each neighbor's vote is weighted by
+its inverse distance.  For regression, the inverse distance weighting
+approach of Shepard [4] is used.
+
+<dt>algorithm (optional)</dt>
+<dd>TEXT, default: 'brute_force'. The name of the algorithm
+used to compute nearest neighbors. The following options are supported:
+<ul>
+<li><b>\ref brute_force</b>: Produces an exact result by searching
+all points in the search space.  Any prefix of the name, such as
+'b' or 'brute', also selects brute force.</li>
+<li><b>\ref kd_tree</b>: Produces an approximate result by searching
+a subset of the search space, that is, only certain leaf nodes in the
+kd-tree as specified by "algorithm_params" below.
+Any prefix of the name, such as 'k' or 'kd', also selects
+kd-tree.</li></ul></dd>
+
+<dt>algorithm_params (optional)</dt>
+<dd>TEXT, default: 'depth=3, leaf_nodes=2'. These parameters apply to the
+kd-tree algorithm only.
+<ul>
+<li><b>\ref depth</b>: Depth of the kd-tree; a tree of depth \f$d\f$ has
+\f$2^d\f$ leaf nodes. Increasing this value decreases run time but reduces accuracy.</li>
+<li><b>\ref leaf_nodes</b>: Number of leaf nodes (regions) to search for each test point.
+Increasing this value improves accuracy but increases run time.</li></ul>
 
-For classification, majority voting weighs a neighbor
-according to inverse distance.
+@note
+The kd-tree accuracy will be lower for datasets with a large
+number of features. Using at least two leaf nodes is recommended.
+Refer to the <a href="#background">Technical Background</a> for more information
+on how the kd-tree is implemented.</dd>
 
-For regression, the inverse distance weighting approach is
-used from Shepard [4].
 </dl>
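The following sketch shows the weighting used when `weighted_avg` is enabled. The zero-distance case follows the removed SQL in `knn.py_in` earlier in this diff. That SQL substitutes `MAX_WEIGHT_ZERO_DIST` for `1.0 / dist` when the distance is 0. The diff does not show the constant's value, so `1e6` below is a placeholder. The weight is the inverse of whatever `fn_dist` returns, so with the default `squared_dist_norm2` it is the inverse squared Euclidean distance. The regression formula (the weighted mean of neighbor values) is the standard Shepard form. Treat it as an assumption about the implementation.

```python
from collections import defaultdict

# Placeholder for MAX_WEIGHT_ZERO_DIST; the real value is not shown in the diff.
MAX_WEIGHT_ZERO_DIST = 1e6


def inverse_distance_weight(d):
    """Weight of a neighbor at distance d; an exact match gets a capped weight."""
    return MAX_WEIGHT_ZERO_DIST if d == 0.0 else 1.0 / d


def weighted_vote(neighbors):
    """neighbors: list of (distance, label). Return the label whose summed
    inverse-distance weight is largest."""
    totals = defaultdict(float)
    for d, label in neighbors:
        totals[label] += inverse_distance_weight(d)
    return max(totals, key=totals.get)


def shepard_average(neighbors):
    """Inverse-distance weighted mean of the neighbors' target values."""
    weights = [inverse_distance_weight(d) for d, _ in neighbors]
    return sum(w * y for w, (_, y) in zip(weights, neighbors)) / sum(weights)


# One close 'a' neighbor (weight 1.0) beats two farther 'b' neighbors (0.25 each).
print(weighted_vote([(1.0, 'a'), (4.0, 'b'), (4.0, 'b')]))  # a
print(shepard_average([(1.0, 10.0), (4.0, 0.0)]))           # 8.0
```

An unweighted majority vote on the same three neighbors would pick 'b'. This illustrates how weighting lets a single close neighbor dominate.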
 
 
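The `algorithm_params` string is a comma-separated list of `key=value` pairs, and any missing key takes its default. The parser below is a hypothetical stand-in for the `extract_keyvalue_params` helper that `knn()` calls. It only illustrates the expected format and the default handling. It does not reproduce that helper's error messages or validation; range checks on `depth` and `leaf_nodes` appear to happen separately in `knn_validate_src`.

```python
# Defaults documented for algorithm_params: 'depth=3, leaf_nodes=2'.
DEFAULTS = {'depth': 3, 'leaf_nodes': 2}
PARAM_TYPES = {'depth': int, 'leaf_nodes': int}


def parse_kd_tree_params(param_str):
    """Parse 'key=value, key=value' into a dict, filling in defaults.

    Hypothetical stand-in for MADlib's extract_keyvalue_params().
    """
    params = dict(DEFAULTS)
    if not param_str:
        return params
    for item in param_str.split(','):
        if not item.strip():
            continue  # tolerate a trailing comma
        key, sep, value = item.partition('=')
        key = key.strip()
        if not sep or key not in PARAM_TYPES:
            raise ValueError("Invalid kd-tree parameter: {0!r}".format(item.strip()))
        params[key] = PARAM_TYPES[key](value.strip())
    return params


print(parse_kd_tree_params('depth=4, leaf_nodes=8'))  # {'depth': 4, 'leaf_nodes': 8}
print(parse_kd_tree_params('leaf_nodes=1'))           # {'depth': 3, 'leaf_nodes': 1}
print(parse_kd_tree_params(None))                     # {'depth': 3, 'leaf_nodes': 2}
```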
@@ -234,7 +258,7 @@ INSERT INTO knn_train_data_reg VALUES
 
 -#  Prepare some testing data:
 <pre class="example">
-DROP TABLE IF EXISTS knn_test_data;
+DROP TABLE IF EXISTS knn_test_data CASCADE;
 CREATE TABLE knn_test_data (
                     id integer,
                     data integer[]
@@ -375,6 +399,73 @@ SELECT * FROM knn_result_classification ORDER BY id;
 (6 rows)
 </pre>
 
+-# Use the kd-tree option.  Here we build a kd-tree of depth 4, which has
+2^4 = 16 leaf nodes, and search half (8) of them for each test point:
+<pre class="example">
+DROP TABLE IF EXISTS knn_result_classification_kd;
+SELECT madlib.knn(
+                'knn_train_data',        -- Table of training data
+                'data',                  -- Col name of training data
+                'id',                    -- Col name of id in train data
+                NULL,                    -- Training labels
+                'knn_test_data',         -- Table of test data
+                'data',                  -- Col name of test data
+                'id',                    -- Col name of id in test data
+                'knn_result_classification_kd',  -- Output table
+                 3,                      -- Number of nearest neighbors
+                 True,                   -- True to list nearest-neighbors by id
+                 'madlib.squared_dist_norm2', -- Distance function
+                 False,                  -- For weighted average
+                 'kd_tree',              -- Use kd-tree
+                 'depth=4, leaf_nodes=8' -- Kd-tree options
+                 );
+SELECT * FROM knn_result_classification_kd ORDER BY id;
+</pre>
+<pre class="result">
+ id |  data   | k_nearest_neighbours
+----+---------+----------------------
+  1 | {2,1}   | {1,2,3}
+  2 | {2,6}   | {5,4,3}
+  3 | {15,40} | {7,6,5}
+  4 | {12,1}  | {4,5,3}
+  5 | {2,90}  | {9,6,7}
+  6 | {50,45} | {6,7,8}
+(6 rows)
+</pre>
+The result above is the same as for brute force. If we search only 1 leaf node,
+the run time will be shorter but the accuracy lower. On this very small data set,
+this shows up as some test points getting fewer than 3 nearest neighbors, and
+test point 4 getting none, so it is missing from the output:
+<pre class="example">
+DROP TABLE IF EXISTS knn_result_classification_kd;
+SELECT madlib.knn(
+                'knn_train_data',        -- Table of training data
+                'data',                  -- Col name of training data
+                'id',                    -- Col name of id in train data
+                NULL,                    -- Training labels
+                'knn_test_data',         -- Table of test data
+                'data',                  -- Col name of test data
+                'id',                    -- Col name of id in test data
+                'knn_result_classification_kd',  -- Output table
+                 3,                      -- Number of nearest neighbors
+                 True,                   -- True to list nearest-neighbors by id
+                 'madlib.squared_dist_norm2', -- Distance function
+                 False,                  -- For weighted average
+                 'kd_tree',              -- Use kd-tree
+                 'depth=4, leaf_nodes=1' -- Kd-tree options
+                 );
+SELECT * FROM knn_result_classification_kd ORDER BY id;
+</pre>
+<pre class="result">
+ id |  data   | k_nearest_neighbours
+----+---------+----------------------
+  1 | {2,1}   | {1}
+  2 | {2,6}   | {3,2}
+  3 | {15,40} | {7}
+  5 | {2,90}  | {3,2}
+  6 | {50,45} | {6,8}
+(5 rows)
+</pre>
+
 @anchor background
 @par Technical Background
 
@@ -382,11 +473,37 @@ The training data points are vectors in a multidimensional feature space,
 each with a class label. The training phase of the algorithm consists
 only of storing the feature vectors and class labels of the training points.
 
-In the classification phase, k is a user-defined constant, and an unlabeled
-vector (a test point) is classified by assigning the label which is most
-frequent among the k training samples nearest to that test point.
-In case of regression, average of the values of these k training samples
-is assigned to the test point.
+In the prediction phase, \f$k\f$ is a user-defined constant, and an unlabeled vector
+(a test point) is predicted by using the labels of the \f$k\f$ training samples
+nearest to that test point.
+
+Since distances between points are used to find the nearest neighbors, the data
+should be standardized across features. This ensures that all features are given
+equal weight in the distance computation.
+
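One common way to standardize is to convert each feature to z-scores. The mean and standard deviation are computed on the training data, and the same transformation is then applied to the test data. The sketch below is illustrative only; `fit_scaler` is a hypothetical helper, not a MADlib function.

```python
import statistics


def fit_scaler(train_points):
    """Learn per-feature mean and standard deviation from the training points
    and return a function that maps any point to z-scores.

    A constant feature (standard deviation 0) is mapped to 0.0.
    """
    columns = list(zip(*train_points))
    means = [statistics.mean(c) for c in columns]
    stdevs = [statistics.pstdev(c) for c in columns]

    def transform(point):
        return [(x - m) / s if s > 0 else 0.0
                for x, m, s in zip(point, means, stdevs)]

    return transform


# Feature 2 is on a 1000x larger scale than feature 1; after scaling,
# both contribute equally to the distance.
scale = fit_scaler([(1, 1000), (2, 2000), (3, 3000)])
print(scale((2, 2000)))  # [0.0, 0.0]
```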
+An approximation method can be used to speed up the prediction phase by building
+appropriate data structures in the training phase. An example of such a data
+structure is the kd-tree [5]. Using the kd-tree algorithm can improve the execution
+time of the \f$k\f$-NN operation, but at the expense of some accuracy. The
+kd-tree implementation divides the training dataset into multiple regions that
+correspond to the leaf nodes of a tree. For example, a tree of depth \f$3\f$ will have
+a total of \f$2^3 = 8\f$ regions. The algorithm looks for the nearest neighbors
+in a subset of all regions instead of searching the whole dataset. For a given
+test point, the first (home) region is found by traversing the tree to the leaf
+node that contains the point. If the user requests additional leaf nodes to be
+searched, we compute the distance between the point and the centroids of the other
+regions and expand the search to the regions with the closest centroids until the
+requested number of leaf nodes has been searched.
+
+Note that each level of the kd-tree splits on a single feature, and the
+features are used in the same order as they appear in the data: the first level
+splits on the first feature, the second level on the second feature, and so on.
+
+The kd-tree accuracy might suffer on datasets with a large number of features
+(dimensions). For example, consider a dataset with 20 features and a kd-tree of
+depth 3. The kd-tree is then constructed from the first 3 features only. As a
+result, the search can miss true nearest neighbors that are close in the
+remaining 17 dimensions but were assigned to a farther region (the distance
+computation itself still uses all 20 features).
 
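The region-based search described above can be sketched in memory. This is a simplified illustration, not the SQL implementation in `build_kd_tree` and `knn_kd_tree`. It rests on three assumptions: each split is at the median of the current feature, level l splits on feature `l % n_features`, and the `leaf_nodes` count includes the home region (`leaf_nodes >= 1`). All function names are hypothetical.

```python
import statistics


def squared_dist(x, y):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def build_regions(points, depth):
    """Partition points, a list of (id, vector), into 2**depth leaf regions.

    Level `level` splits on feature `level % n_features` (features in data
    order) at that feature's median; the median rule is an assumption.
    Returns (splits, leaves): splits maps a node path (tuple of 0/1 moves)
    to its (feature, cut); leaves maps each full-depth path to its points.
    """
    n_features = len(points[0][1])
    splits, leaves = {}, {}

    def split(path, subset):
        level = len(path)
        if level == depth:
            leaves[path] = subset
            return
        feature = level % n_features
        cut = statistics.median(v[feature] for _, v in subset) if subset else 0.0
        splits[path] = (feature, cut)
        split(path + (0,), [p for p in subset if p[1][feature] <= cut])
        split(path + (1,), [p for p in subset if p[1][feature] > cut])

    split((), points)
    return splits, leaves


def home_leaf(splits, depth, query):
    """Walk from the root to the leaf region that would contain `query`."""
    path = ()
    while len(path) < depth:
        feature, cut = splits[path]
        path += (0,) if query[feature] <= cut else (1,)
    return path


def centroid(subset):
    """Mean vector of a non-empty list of (id, vector) points."""
    return [statistics.mean(col) for col in zip(*(v for _, v in subset))]


def approx_knn(splits, leaves, depth, query, k, leaf_nodes):
    """Search the home region plus the (leaf_nodes - 1) non-empty regions
    whose centroids are closest to `query`; return up to k neighbor ids."""
    home = home_leaf(splits, depth, query)
    others = [p for p, pts in leaves.items() if p != home and pts]
    others.sort(key=lambda p: squared_dist(centroid(leaves[p]), query))
    regions = [home] + others[:leaf_nodes - 1]
    candidates = [pt for r in regions for pt in leaves[r]]
    candidates.sort(key=lambda pt: squared_dist(pt[1], query))
    return [pid for pid, _ in candidates[:k]]


# Hypothetical toy data: (id, vector)
train = [(1, (1, 1)), (2, (2, 2)), (3, (3, 3)), (4, (8, 8)),
         (5, (9, 9)), (6, (10, 10)), (7, (1, 9)), (8, (9, 1))]
splits, leaves = build_regions(train, depth=1)
print(approx_knn(splits, leaves, 1, (5.4, 5.4), k=2, leaf_nodes=1))  # [3, 2]
print(approx_knn(splits, leaves, 1, (5.4, 5.4), k=2, leaf_nodes=2))  # [3, 4]
```

With depth 1 there are 2 regions. For the query (5.4, 5.4), searching only the home region returns [3, 2]. Searching both regions recovers the exact answer [3, 4]. This is the same accuracy versus run-time trade-off shown in the `leaf_nodes=1` example earlier.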
 @anchor literature
 @literature
@@ -404,14 +521,20 @@ is assigned to the test point.
     https://ai2-s2-pdfs.s3.amazonaws.com/a7e2/814ec5db800d2f8c4313fd436e9cf8273821.pdf
 
 @anchor knn-lit-4
-[4]    Shepard, Donald (1968). "A two-dimensional interpolation function for
+[4] Shepard, Donald (1968). "A two-dimensional interpolation function for
 irregularly-spaced data". Proceedings of the 1968 ACM National Conference. pp. 517–524.
 
+@anchor knn-lit-5
+[5] Bentley, J. L. (1975). "Multidimensional binary search trees used for
+associative searching". Communications of the ACM. 18 (9): 509–517. doi:10.1145/361002.361007.
+
+
 @internal
 @sa namespace knn (documenting the implementation in Python)
 @endinternal
 */
 
+
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__knn_validate_src(
     point_source VARCHAR,
     point_column_name VARCHAR,
@@ -440,12 +563,61 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
     k INTEGER,
     output_neighbors BOOLEAN,
     fn_dist TEXT,
-    weighted_avg BOOLEAN
+    weighted_avg BOOLEAN,
+    algorithm VARCHAR,
+    algorithm_params VARCHAR
 ) RETURNS VARCHAR AS $$
     PythonFunction(`knn', `knn', `knn')
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
+    point_source VARCHAR,
+    point_column_name VARCHAR,
+    point_id VARCHAR,
+    label_column_name VARCHAR,
+    test_source VARCHAR,
+    test_column_name VARCHAR,
+    test_id VARCHAR,
+    output_table VARCHAR,
+    k INTEGER,
+    output_neighbors BOOLEAN,
+    fn_dist TEXT,
+    weighted_avg BOOLEAN,
+    algorithm VARCHAR
+) RETURNS VARCHAR AS $$
+    DECLARE
+    returnstring VARCHAR;
+BEGIN
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,
+                                     NULL);
+    RETURN returnstring;
+END;
+$$ LANGUAGE plpgsql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
+    point_source VARCHAR,
+    point_column_name VARCHAR,
+    point_id VARCHAR,
+    label_column_name VARCHAR,
+    test_source VARCHAR,
+    test_column_name VARCHAR,
+    test_id VARCHAR,
+    output_table VARCHAR,
+    k INTEGER,
+    output_neighbors BOOLEAN,
+    fn_dist TEXT,
+    weighted_avg BOOLEAN
+) RETURNS VARCHAR AS $$
+    DECLARE
+    returnstring VARCHAR;
+BEGIN
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,
+                                     NULL, NULL);
+    RETURN returnstring;
+END;
+$$ LANGUAGE plpgsql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
     point_source VARCHAR,
@@ -463,7 +635,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
     DECLARE
     returnstring VARCHAR;
 BEGIN
-    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11, FALSE);
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,
+                                     FALSE, NULL, NULL);
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
@@ -486,7 +659,8 @@ DECLARE
     returnstring VARCHAR;
 BEGIN
     returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,
-                                     'MADLIB_SCHEMA.squared_dist_norm2', FALSE);
+                                     'MADLIB_SCHEMA.squared_dist_norm2', FALSE,
+                                     NULL, NULL);
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
@@ -507,7 +681,8 @@ DECLARE
     returnstring VARCHAR;
 BEGIN
     returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE,
-                                     'MADLIB_SCHEMA.squared_dist_norm2', FALSE);
+                                     'MADLIB_SCHEMA.squared_dist_norm2', FALSE,
+                                     NULL, NULL);
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
@@ -527,7 +702,8 @@ DECLARE
     returnstring VARCHAR;
 BEGIN
     returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE,
-                                     'MADLIB_SCHEMA.squared_dist_norm2',FALSE);
+                                     'MADLIB_SCHEMA.squared_dist_norm2',FALSE,
+                                     NULL, NULL);
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
@@ -546,4 +722,3 @@ RETURNS VARCHAR AS $$
     SELECT MADLIB_SCHEMA.knn('');
 $$ LANGUAGE sql IMMUTABLE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
-
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index 6dbed36..86f0eb4 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -8,7 +8,7 @@
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
- *   http://www.apache.org/licenses/LICENSE-2.0
+ *  http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
@@ -25,12 +25,12 @@ m4_include(`SQLCommon.m4')
  *
  * -------------------------------------------------------------------------- */
 
-drop table if exists knn_train_data;
-create table knn_train_data (
+DROP TABLE if exists knn_train_data;
+create TABLE knn_train_data (
 id  integer,
 data    integer[],
 label   integer);
-copy knn_train_data (id, data, label) from stdin delimiter '|';
+copy knn_train_data (id, data, label) FROM stdin delimiter '|';
 1|{1,1}|1
 2|{2,2}|1
 3|{3,3}|1
@@ -47,7 +47,7 @@ CREATE TABLE knn_train_data_reg (
                     data integer[],
                     label float
                     );
-COPY knn_train_data_reg (id, data, label) from stdin delimiter '|';
+COPY knn_train_data_reg (id, data, label) FROM stdin delimiter '|';
 1|{1,1}|1.0
 2|{2,2}|1.0
 3|{3,3}|1.0
@@ -59,10 +59,10 @@ COPY knn_train_data_reg (id, data, label) from stdin delimiter '|';
 9|{1,111}|0.0
 \.
 DROP TABLE IF EXISTS knn_test_data;
-create table knn_test_data (
+create TABLE knn_test_data (
 id  integer,
 data integer[]);
-copy knn_test_data (id, data) from stdin delimiter '|';
+copy knn_test_data (id, data) FROM stdin delimiter '|';
 1|{2,1}
 2|{2,6}
 3|{15,40}
@@ -70,13 +70,13 @@ copy knn_test_data (id, data) from stdin delimiter '|';
 5|{2,90}
 6|{50,45}
 \.
-drop table if exists knn_train_data_expr;
-create table knn_train_data_expr (
+DROP TABLE if exists knn_train_data_expr;
+create TABLE knn_train_data_expr (
 id  integer,
-data1    integer,
+data1   integer,
 data2    integer,
 label   integer);
-copy knn_train_data_expr (id, data1 , data2, label) from stdin delimiter '|';
+copy knn_train_data_expr (id, data1 , data2, label) FROM stdin delimiter '|';
 1| 1  |  1  |1
 2| 2  |  2  |1
 3| 3  |  3  |1
@@ -90,77 +90,242 @@ copy knn_train_data_expr (id, data1 , data2, label) from stdin delimiter '|';
 
 
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
-select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3);
-select assert(array_agg(x order by id)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x, id from madlib_knn_result_classification where id = 1 order by x asc) y;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3);
+SELECT assert(array_agg(x ORDER BY id)= '{1,2,3}','Wrong output in classification with k=3') FROM (SELECT unnest(k_nearest_neighbours) AS x, id FROM madlib_knn_result_classification WHERE id = 1 ORDER BY x ASC) y;
 
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
-select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') FROM madlib_knn_result_regression;
 
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True);
-select assert(array_agg(x order by id)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x, id from madlib_knn_result_regression where id = 1 order by x asc) y;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True);
+SELECT assert(array_agg(x ORDER BY id)= '{1,2,3}' , 'Wrong output in regression with k=3') FROM (SELECT unnest(k_nearest_neighbours) AS x, id FROM madlib_knn_result_regression WHERE id = 1 ORDER BY x ASC) y;
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL,False);
-select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL,False);
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_norm1');
-select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_norm1');
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_angle');
-select assert(array_agg(prediction order by id)='{1,0,0,1,0,1}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_angle');
+SELECT assert(array_agg(prediction ORDER BY id)='{1,0,0,1,0,1}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_tanimoto');
-select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.dist_tanimoto');
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
 
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_norm1');
-select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_norm1');
+SELECT assert(array_agg(prediction ORDER BY id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') FROM madlib_knn_result_regression;
 
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_angle');
-select assert(array_agg(prediction order by id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_angle');
+SELECT assert(array_agg(prediction ORDER BY id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') FROM madlib_knn_result_regression;
 
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') FROM madlib_knn_result_classification;
 
 
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') FROM madlib_knn_result_regression;
 
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification') FROM madlib_knn_result_classification;
 
-drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression;
+DROP TABLE if exists madlib_knn_result_regression;
+SELECT knn('knn_train_data_reg','data[1:1]||data[2:2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') FROM madlib_knn_result_regression;
 
 
 
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data_expr','ARRAY[data1,data2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
-select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification') from madlib_knn_result_classification;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data_expr','ARRAY[data1,data2]','id','label','knn_test_data','data[1:1]||data[2:2]','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+SELECT assert(array_agg(prediction::numeric ORDER BY id)='{1,1,0,1,0,0}', 'Wrong output in classification') FROM madlib_knn_result_classification;
 
 
 
-drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id',NULL,'knn_test_data','data','id','madlib_knn_result_classification',3);
-select assert(array_agg(x order by id)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x, id from madlib_knn_result_classification where id = 1 order by x asc) y;
+DROP TABLE if exists madlib_knn_result_classification;
+SELECT knn('knn_train_data','data','id',NULL,'knn_test_data','data','id','madlib_knn_result_classification',3);
+SELECT assert(array_agg(x ORDER BY x)= '{1,2,3}','Wrong output in classification with k=3') FROM (SELECT unnest(k_nearest_neighbours) AS x, id FROM madlib_knn_result_classification WHERE id = 1) y;
 
-select knn();
-select knn('help');
+SELECT knn();
+SELECT knn('help');
+
+
+
+DROP TABLE if exists knn_train_data2;
+CREATE TABLE knn_train_data2 (
+    id integer,
+    data double precision[],
+    label integer
+);
+COPY knn_train_data2 (id, data, label) FROM stdin delimiter '|';
+1|{43983,164834}|0
+2|{491231,38953}|0
+3|{587484,467668}|0
+4|{882448,507209}|0
+5|{17326,595844}|0
+6|{236408,453230}|0
+7|{283929,237605}|0
+8|{392623,153808}|0
+9|{267864,179054}|0
+10|{428486,618138}|0
+11|{963752,141363}|0
+12|{980623,652584}|0
+13|{398411,894748}|0
+14|{559681,670919}|0
+15|{297984,171933}|0
+16|{254190,341966}|0
+17|{336766,745420}|0
+18|{380918,924250}|0
+19|{213087,263365}|0
+20|{431458,230413}|0
+21|{859208,667865}|0
+22|{683642,143136}|0
+23|{905470,76265}|0
+24|{296944,173333}|0
+25|{255319,725429}|0
+26|{791471,219070}|0
+27|{866791,772094}|0
+28|{871653,265202}|0
+29|{666841,431334}|0
+30|{936120,964824}|0
+31|{603267,190309}|0
+32|{306790,940033}|1
+33|{935729,687708}|1
+34|{864282,148815}|1
+35|{951072,295739}|1
+36|{379228,810280}|1
+37|{963604,62869}|1
+38|{953416,869073}|1
+39|{139133,250360}|1
+40|{42406,394452}|1
+41|{975789,833877}|1
+42|{613521,842579}|1
+43|{605970,485173}|1
+44|{107780,272810}|1
+45|{916507,43900}|1
+46|{237634,519773}|1
+47|{234208,544424}|1
+48|{459805,169937}|1
+49|{232131,324086}|1
+50|{318751,183202}|1
+51|{619825,697978}|1
+52|{993482,583428}|1
+53|{760847,946898}|1
+54|{452501,899980}|1
+55|{197257,494907}|1
+56|{294431,173045}|1
+57|{328783,907951}|1
+58|{15624,934752}|1
+59|{393124,123404}|1
+60|{207562,309630}|1
+61|{167303,445196}|1
+62|{829402,401511}|1
+63|{989619,289207}|1
+64|{571447,221749}|1
+65|{613292,890198}|1
+66|{404951,233116}|1
+67|{588176,398433}|1
+68|{816544,349023}|1
+69|{345330,269045}|1
+70|{249002,542587}|1
+71|{763951,543433}|1
+72|{715632,92734}|1
+73|{451384,731255}|1
+74|{27485,844507}|1
+75|{854659,235047}|1
+76|{154137,21962}|1
+77|{680243,983539}|1
+78|{423473,669861}|1
+79|{272745,994920}|1
+80|{891610,886037}|1
+81|{885117,296561}|1
+82|{119153,473293}|2
+83|{694994,935696}|2
+84|{822315,40323}|2
+85|{204741,71317}|2
+86|{582910,968691}|2
+87|{614749,298541}|2
+88|{61424,66132}|2
+89|{29796,88909}|2
+90|{910639,884455}|2
+91|{323956,64775}|2
+92|{906416,4198}|2
+93|{48314,329888}|2
+94|{674059,321058}|2
+95|{324807,565669}|2
+96|{207094,209924}|2
+97|{862229,326247}|2
+98|{683217,557222}|2
+99|{261943,505531}|2
+100|{597545,466683}|2
+\.
+
+
+CREATE TABLE knn_test_data2 (
+    id integer NOT NULL,
+    data integer[]
+);
+
+COPY knn_test_data2 (id, data) FROM stdin delimiter '|';
+1|{576848,180455}
+2|{435374,191597}
+3|{478996,496797}
+4|{257729,508791}
+5|{585706,168367}
+\.
+
+DROP TABLE if exists madlib_knn_result_classification_kd;
+
+SELECT knn('knn_train_data2','data','id',NULL,'knn_test_data2','data','id',
+           'madlib_knn_result_classification_kd',1,True,
+           'MADLIB_SCHEMA.squared_dist_norm2',False,
+           'kd_tree', 'depth=2, leaf_nodes=2');
+
+SELECT assert(count(*) > 0, 'Wrong output with kd_tree')
+FROM madlib_knn_result_classification_kd;
+
+DROP TABLE if exists madlib_knn_result_classification_kd;
+
+SELECT knn('knn_train_data2','data','id','label','knn_test_data2','data','id',
+           'madlib_knn_result_classification_kd',2,True,
+           'MADLIB_SCHEMA.squared_dist_norm2',True,
+           'kd_tree', 'depth=2, leaf_nodes=2');
+
+SELECT assert(count(*) > 0, 'Wrong output with kd_tree')
+FROM madlib_knn_result_classification_kd;
+
+DROP TABLE if exists madlib_knn_result_classification_kd;
+
+SELECT knn('knn_train_data', 'data', 'id', NULL, 'knn_test_data', 'data', 'id',
+           'madlib_knn_result_classification_kd', 2, True,
+           'MADLIB_SCHEMA.squared_dist_norm2', False, 'kd_tree',
+           'depth=2, leaf_nodes=1');
+
+SELECT assert(count(*) > 0, 'Wrong output with kd_tree')
+FROM madlib_knn_result_classification_kd;
+
+DROP TABLE if exists madlib_knn_result_classification_kd;
+
+SELECT knn('knn_train_data', 'data', 'id', NULL, 'knn_test_data', 'data', 'id',
+           'madlib_knn_result_classification_kd', 2, True,
+           'MADLIB_SCHEMA.squared_dist_norm2', False, 'kd_tree',
+           'depth=3, leaf_nodes=2');
+
+SELECT assert(count(*) > 0, 'Wrong output with kd_tree')
+FROM madlib_knn_result_classification_kd;
diff --git a/src/ports/postgres/modules/utilities/admin.py_in b/src/ports/postgres/modules/utilities/admin.py_in
index 2fa5e62..6f88fd9 100644
--- a/src/ports/postgres/modules/utilities/admin.py_in
+++ b/src/ports/postgres/modules/utilities/admin.py_in
@@ -11,6 +11,15 @@ def __get_madlib_temp_tables(target_schema):
             """.format(**locals())
     return plpy.execute(sql_get_tables_to_drop)
 
+def __get_madlib_temp_views(target_schema):
+    sql_get_views_to_drop = """
+            SELECT quote_ident(viewname) AS viewname
+            FROM pg_views
+            WHERE viewname LIKE E'%madlib\_temp%'
+            AND quote_ident(schemaname) = '{target_schema}'
+            """.format(**locals())
+    return plpy.execute(sql_get_views_to_drop)
+
 # ------------------------------------------------------------------------------
 def cleanup_madlib_temp_tables(schema_madlib, target_schema, **kwargs):
     """ Drop all tables matching '%madlib_temp%' in the given schema
@@ -65,3 +74,16 @@ def cleanup_madlib_temp_tables_script(schema_madlib, target_schema, **kwargs):
         sql_drop = "DROP TABLE {target_schema}.{tablename};".format(**locals())
         sql_content += sql_drop + "\n"
     return sql_content
+
+# ------------------------------------------------------------------------------
+def cleanup_madlib_temp_views(schema_madlib, target_schema, **kwargs):
+    to_drop_list = __get_madlib_temp_views(target_schema)
+    if len(to_drop_list) == 0:
+        plpy.info("No madlib temp views found in schema {target_schema}.".format(**locals()))
+        return None
+    sql_drop_all = 'DROP VIEW IF EXISTS '
+    sql_drop_all += ",".join(["{target_schema}.{viewname}".format(
+            viewname=row['viewname'], **locals()) for row in to_drop_list])
+    sql_drop_all += " CASCADE;"
+    plpy.notice("Dropping {0} views ...".format(len(to_drop_list)))
+    plpy.execute(sql_drop_all)
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 1b0069f..d2f14a5 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -432,7 +432,6 @@ def is_pg_major_version_less_than(schema_madlib, compare_version, **kwargs):
     version = plpy.execute("select version()")[0]["version"]
     regex = re.compile('PostgreSQL\s*([0-9]+)([0-9.beta]+)', re.IGNORECASE)
     version = regex.findall(version)
-    plpy.info("{0}".format(version))
     if len(version) > 0 and int(version[0][0]) < compare_version:
         return True
     else:
diff --git a/src/ports/postgres/modules/utilities/utilities.sql_in b/src/ports/postgres/modules/utilities/utilities.sql_in
index e598566..7035940 100644
--- a/src/ports/postgres/modules/utilities/utilities.sql_in
+++ b/src/ports/postgres/modules/utilities/utilities.sql_in
@@ -114,6 +114,14 @@ PythonFunction(utilities, admin, cleanup_madlib_temp_tables_script)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cleanup_madlib_temp_views(
+    target_schema text
+)
+RETURNS void AS $$
+PythonFunction(utilities, admin, cleanup_madlib_temp_views)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
 /**
  * @brief Return MADlib build information.
  *
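
The kd_tree tests above pass 'depth=...' and 'leaf_nodes=...' to knn(), and the commit message describes a kd-tree whose depth is capped by the user instead of being built out completely. The following is a minimal pure-Python sketch of that idea, not MADlib's knn.py_in code. It assumes splits at the median of one dimension per level (cycling through dimensions), assumes that extra regions for leaf_nodes > 1 are picked by the distance from the query to each region's centroid, and uses hypothetical helper names (build_kd_tree, kd_knn, exact_knn).

```python
import statistics

# Hypothetical illustration only; not MADlib's knn.py_in implementation.


def sq_dist(a, b):
    """Squared Euclidean distance (analogous to squared_dist_norm2)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def build_kd_tree(items, depth, level=0):
    """Split (id, point) pairs on the median of one dimension per level.

    Dimensions are cycled (level % n_dims) and recursion stops at `depth`,
    so the tree has at most 2**depth leaf regions.
    """
    if level == depth or len(items) < 2:
        return {"leaf": items}
    dim = level % len(items[0][1])
    cut = statistics.median([p[dim] for _, p in items])
    left = [it for it in items if it[1][dim] <= cut]
    right = [it for it in items if it[1][dim] > cut]
    if not left or not right:  # all values equal on this dimension
        return {"leaf": items}
    return {"dim": dim, "cut": cut,
            "left": build_kd_tree(left, depth, level + 1),
            "right": build_kd_tree(right, depth, level + 1)}


def leaves(node):
    """All leaf regions, left to right."""
    if "leaf" in node:
        return [node["leaf"]]
    return leaves(node["left"]) + leaves(node["right"])


def route(node, query):
    """Descend to the single leaf region that contains `query`."""
    while "leaf" not in node:
        side = "left" if query[node["dim"]] <= node["cut"] else "right"
        node = node[side]
    return node["leaf"]


def centroid(region):
    """Mean point of a non-empty leaf region."""
    n = len(region)
    return [sum(p[d] for _, p in region) / n for d in range(len(region[0][1]))]


def kd_knn(tree, query, k, leaf_nodes=1):
    """Approximate k-NN: scan the query's own leaf plus the
    (leaf_nodes - 1) other leaves with the closest centroids (assumption)."""
    own = route(tree, query)
    others = sorted((r for r in leaves(tree) if r is not own),
                    key=lambda r: sq_dist(centroid(r), query))
    candidates = own + [it for r in others[:leaf_nodes - 1] for it in r]
    candidates.sort(key=lambda it: sq_dist(it[1], query))
    return [i for i, _ in candidates[:k]]


def exact_knn(train, query, k):
    """Brute-force reference: scan every training point."""
    ranked = sorted(train.items(), key=lambda it: sq_dist(it[1], query))
    return [i for i, _ in ranked[:k]]


# Rows 1-8 of knn_train_data2 and row 1 of knn_test_data2 from the test above.
train = {1: (43983, 164834), 2: (491231, 38953), 3: (587484, 467668),
         4: (882448, 507209), 5: (17326, 595844), 6: (236408, 453230),
         7: (283929, 237605), 8: (392623, 153808)}
query = (576848, 180455)
tree = build_kd_tree(list(train.items()), depth=2)
print(exact_knn(train, query, k=3))             # [2, 8, 3]
print(kd_knn(tree, query, k=3, leaf_nodes=1))   # [2, 8]
print(kd_knn(tree, query, k=3, leaf_nodes=2))   # [2, 8, 3]
```

With depth=2 the eight points fall into four regions of two points each. Searching only the query's own region (leaf_nodes=1) returns fewer than k neighbours, while searching one more region recovers the exact answer here; in this sketch, a deeper tree shrinks the search space and more leaf_nodes buys back accuracy.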

