import java.awt.event.*;
import java.awt.*;
import java.applet.*;
import java.lang.Math;

import java.util.Vector;

// Applet front end of the nearest-neighbor demo.
// It lays out the counters, the drawing canvas and the control buttons,
// and forwards every button press to the DrawCanvas.
public class Nearest extends Applet implements ActionListener
{
    DrawCanvas canvas;          // Main drawing surface
    TextField MessageText;      // Read-only status line
    TextField knnText;          // Shows how many nearest neighbors are searched
    TextField nPointText;       // Shows how many points the data set holds
    Button ResetButton;         // Removes every point
    Button RandomButton;        // Adds uniformly distributed points
    Button NormalRandomButton;  // Adds normally distributed points
    Button NNButton;            // Toggles between search mode and edit mode
    Button AddNNButton;         // Searches one more neighbor
    Button RemoveNNButton;      // Searches one neighbor less
    Button SimButton;           // Starts the random-query simulation

    // Builds the user interface. Components are added in display order
    // because the container lays them out in insertion order.
    public void init()
    {
        Rectangle area = getBounds();

        // Point counter
        nPointText = new TextField(4);
        nPointText.setText("0");
        nPointText.setEditable(false);
        add(new Label("Number of points:"));
        add(nPointText);

        // Neighbor counter with its +/- buttons
        knnText = new TextField(3);
        knnText.setText("1");
        knnText.setEditable(false);
        add(new Label("Number of Nearest neighbor to search:"));
        add(knnText);
        AddNNButton = addButton("+");
        RemoveNNButton = addButton("-");

        // The status line must exist before the canvas, which writes into it
        MessageText = new TextField("Welcome to the nearest neighbor applet!!");
        MessageText.setEditable(false);
        canvas = new DrawCanvas(MessageText, nPointText);
        canvas.setSize(area.width - 10, area.height - 140);
        add(canvas);

        // Data set controls
        ResetButton = addButton("Reset Points");
        RandomButton = addButton("Uniform Dist");
        NormalRandomButton = addButton("Normal Dist");

        // Mode switch
        add(new Label("Switch mode:"));
        NNButton = addButton("Nearest Neighbor");

        // Simulation
        add(new Label("Simulate a 1000 random tests points:"));
        SimButton = addButton("Go Simulation!!!");

        // The status line goes last
        add(MessageText);
    }

    // Creates a button with the given caption, appends it to the applet
    // and registers this applet as its action listener.
    private Button addButton(String caption)
    {
        Button button = new Button(caption);
        add(button);
        button.addActionListener(this);
        return button;
    }

    // Detaches the canvas when the applet is torn down.
    public void destroy()
    {
        remove(canvas);
    }

    // Draws a raised light gray frame around the applet area.
    public void paint(Graphics g)
    {
        Rectangle area = getBounds();
        int frameWidth = area.width - 1;
        int frameHeight = area.height - 1;

        g.setColor(Color.lightGray);
        g.draw3DRect(0, 0, frameWidth, frameHeight, true);
    }

    // Dispatches a button press to the matching canvas operation.
    public void actionPerformed(ActionEvent e)
    {
        Object source = e.getSource();

        if (source == ResetButton)
        {
            MessageText.setText("Points reset");
            canvas.resetPoints();
            return;
        }
        if (source == SimButton)
        {
            canvas.DoSimulation();
            return;
        }
        if (source == RandomButton)
        {
            MessageText.setText("Random Points generated");
            canvas.randomPoints(100);
            return;
        }
        if (source == NormalRandomButton)
        {
            MessageText.setText("Normal Random points generated");
            canvas.normalRandomPoints(100);
            return;
        }
        if (source == NNButton)
        {
            toggleMode();
            return;
        }
        if (source == AddNNButton)
        {
            changeNeighborCount(+1);
            return;
        }
        if (source == RemoveNNButton)
        {
            changeNeighborCount(-1);
        }
    }

    // Switches the canvas between edit mode and search mode. The buttons
    // that modify the data set are only enabled in edit mode.
    private void toggleMode()
    {
        boolean enterSearch = (canvas.getClickMode() == DrawCanvas.ADDPOINTS);

        MessageText.setText(enterSearch ? "Search mode" : "Edit mode");
        canvas.setClickMode(enterSearch ? DrawCanvas.FINDNEAREST : DrawCanvas.ADDPOINTS);
        NNButton.setLabel(enterSearch ? "Edit Points" : "Nearest Neighbor");

        boolean editable = !enterSearch;
        ResetButton.setEnabled(editable);
        RandomButton.setEnabled(editable);
        NormalRandomButton.setEnabled(editable);
    }

    // Asks the canvas to change the neighbor count by delta and shows the
    // value the canvas actually kept (it may refuse the change).
    private void changeNeighborCount(int delta)
    {
        int requested = canvas.getKN() + delta;
        canvas.setKN(requested);
        knnText.setText(String.valueOf(canvas.getKN()));
    }

    // Returns the applet description string.
    public String getAppletInfo()
    {
        return "A simple drawing program.";
    }
}

// Drawing surface of the demo. It owns the point set, lets the user add
// points (ADDPOINTS mode) or query nearest neighbors (FINDNEAREST mode),
// and reports results through the two text fields handed to the constructor.
class DrawCanvas extends Canvas implements MouseListener 
{
    // Click modes
    public static final int FINDNEAREST = 0;
    public static final int ADDPOINTS = 1;
 
    // Nearest neighbor finder; rebuilt when entering search mode and at
    // the start of every simulation
    NNFinder nnfind;
    // Number of nearest neighbors requested by the user
    int kn;
    // Current click mode (FINDNEAREST or ADDPOINTS)
    int mode;
    // When true, the next update() clears the surface and redraws every point
    boolean DrawAllFlag;
    // Dynamic array of java.awt.Point
    Vector points;
    // Last clicked point: the newest data point in edit mode, the query
    // point in search mode; null when there is none
    Point point;
    
    TextField MessageText;  // Status line
    TextField nPointText;   // Point counter
    
    // Creates an empty canvas in edit mode.
    //   mtext    : field that receives status messages
    //   nptstext : field that receives the current number of points
    public DrawCanvas(TextField mtext, TextField nptstext) 
    {
        mode = ADDPOINTS;
        DrawAllFlag = false;
        points = new Vector();
        MessageText = mtext;
        nPointText = nptstext;
        kn = 1;
        setBackground(Color.white);
        addMouseListener(this);
    }
    
    // Sets the number of neighbors to search. The request is silently
    // ignored unless 0 < k < number of points.
    public void setKN(int k)
    {
        if(k > 0 && k < points.size()) kn = k;
    }
    
    // Returns the number of neighbors requested.
    public int getKN()
    {
        return kn;
    }
    
    // Switches the click mode, forgets the current point and schedules a
    // full redraw. Entering search mode rebuilds the finder from the
    // current point set.
    public void setClickMode(int newmode) 
    {
        point = null;
        mode = newmode;
        if(mode == FINDNEAREST)
        {
            nnfind = new NNFinder(points);
        }
        DrawAllFlag = true;
        repaint();
    }

    // Runs 1000 searches from random query points inside the canvas and
    // shows the average number of distance computations next to the
    // brute-force count. Does nothing but report when there are no points.
    public void DoSimulation()
    {
        if(points.isEmpty())
        {
            // Searching an empty set would index past the finder's arrays
            MessageText.setText("No points to search");
            return;
        }
        Dimension dim = getSize();
        nnfind = new NNFinder(points);
        // kn may be stale (larger than the data set) after a reset
        int wanted = Math.min(kn, nnfind.ndata);
        double res = 0.0;
        for(int i = 0; i < 1000; i++)
        {
            Point probe = new Point((int)(Math.random()*(double)dim.width), 
                                    (int)(Math.random()*(double)dim.height));
            nnfind.FindFirstNN(probe);
            for(int k = 1; k < wanted; k++) nnfind.FindNextNN();
            res += (double)nnfind.getCompare();
        }
        res /= 1000.0;
        // Brute force compares every remaining point for each neighbor
        double brute = 0.0;
        for(int k = 0; k < wanted; k++) brute += (double)(nnfind.ndata - k);
        MessageText.setText("Average Number Of Compare: "+(int)res+"/"+(int)brute+" ("
            + (int)(res*100/brute)+"%)");
    }

    // Returns the current click mode.
    public int getClickMode()
    {
        return mode;
    }
    
    // Removes every point and schedules a full redraw.
    public void resetPoints()
    {
        points.removeAllElements();
        point = null;
        DrawAllFlag = true;
        nPointText.setText("0");
        repaint();
    }

    // Appends npts points drawn uniformly over the canvas area and
    // schedules a full redraw.
    public void randomPoints(int npts)
    {
        Dimension dim = getSize();
        int npoints = points.size();
        points.ensureCapacity(npoints+npts);
        for(int i = 0; i < npts; i++)
        {
            Point p = new Point((int)(Math.random()*(double)dim.width), 
                                (int)(Math.random()*(double)dim.height));
            points.addElement(p);
        }
        DrawAllFlag = true;
        point = null;
        nPointText.setText(String.valueOf(npoints+npts));
        repaint();
    }
    
    // Returns (sqrt(-2*ln(U)) + mean) * sigma with U uniform in [0,1).
    // This is the radial part of the Box-Muller transform; the caller
    // combines it with a uniform angle to get a 2-D bell-shaped cloud.
    // Code taken from : http://jeff.cs.mcgill.ca/~luc/rng.html
    public double randNorm(double mean, double sigma)
    {
        double u = Math.sqrt(-2.0*Math.log(Math.random()));
        return (u+mean)*sigma;
    }
    
    // Appends npts points clustered around the canvas center (random
    // radius from randNorm, uniform angle). Candidates falling outside
    // the canvas are rejected and redrawn. Does nothing when the canvas
    // is too small to contain any acceptable point.
    public void normalRandomPoints(int npts)
    {
        Dimension dim = getSize();
        // Accepted coordinates satisfy 0 < u < width and 0 < v < height;
        // below 2x2 no such integer exists and the loop would never end
        if(dim.width < 2 || dim.height < 2) return;

        int npoints = points.size();
        points.ensureCapacity(npoints+npts);
        int i = 0;
        while(i < npts)
        {
            double R = randNorm(0.0, 1.0)*(double)dim.height/4.0;
            double Theta = Math.random()*Math.PI*2.0;
            int u = (int)(dim.width/2 + R*Math.cos(Theta));
            int v = (int)(dim.height/2 + R*Math.sin(Theta));
            if(u > 0 && u < dim.width && v > 0 && v < dim.height)
            {
                points.addElement(new Point(u, v));
                i++;
            }
        }
        DrawAllFlag = true;
        point = null;
        nPointText.setText(String.valueOf(npoints+npts));
        repaint();
    }
    
    // In search mode the click becomes the query point (full redraw);
    // in edit mode it becomes a new data point (incremental redraw).
    public void mousePressed(MouseEvent e) 
    {
        e.consume();
        switch (mode) 
        {
            case FINDNEAREST:
                point = new Point(e.getX(), e.getY());                
                DrawAllFlag = true;
                repaint();
                break;
            case ADDPOINTS:
                point = new Point(e.getX(), e.getY());
                points.addElement(point);
                nPointText.setText(String.valueOf(points.size()));
                repaint();         
                break;
        }
    }
    public void mouseReleased(MouseEvent e) {}
    public void mouseEntered(MouseEvent e) {}
    public void mouseExited(MouseEvent e) {}
    public void mouseClicked(MouseEvent e) {}
    
    // Application-triggered repaint (reached through repaint()).
    // Clears and redraws everything only when a full redraw was
    // requested; otherwise just draws on top of the existing content.
    public void update(Graphics g)
    {
        boolean drawAll = DrawAllFlag;
        DrawAllFlag = false;
        if(drawAll)
        {
            Rectangle bound = getBounds();
            g.clearRect(0,0,bound.width,bound.height);
        }
        render(g, drawAll);
    }
    
    // System-triggered repaint (e.g. the window was uncovered). The old
    // content cannot be trusted, so the whole point set is always redrawn;
    // drawing only on a pending DrawAllFlag would lose every point but the
    // last one. The flag is left alone so a queued update() still runs.
    public void paint(Graphics g) 
    {
        Rectangle bound = getBounds();
        g.clearRect(0,0,bound.width,bound.height);
        render(g, true);
    }

    // Draws the border, optionally every data point, and the decoration
    // of the current point according to the click mode.
    private void render(Graphics g, boolean drawAll)
    {
        Rectangle bound = getBounds();

        // width-1/height-1 keeps the right and bottom edges inside the canvas
        g.setColor(Color.lightGray);
        g.draw3DRect(0, 0, bound.width-1, bound.height-1, false);

        if(drawAll)
        {
            g.setColor(Color.blue);
            g.setPaintMode();
            int np = points.size();        
            for (int i=0; i < np; i++) 
            {
                Point p = (Point)points.elementAt(i);
                g.drawOval(p.x-2, p.y-2, 4,4);
            }
        }
        if(point == null) return;

        if(mode == ADDPOINTS)
        {
            g.setColor(Color.blue);
            g.drawOval(point.x-2, point.y-2, 4,4);
        }
        else if(mode == FINDNEAREST)
        {
            drawSearch(g, bound);
        }
    }

    // Runs the search for the current query point, draws the query (red),
    // the neighbors found (green), the circle through the last neighbor
    // (gray) and the scanned region (light gray), then reports the number
    // of distance computations in the status line.
    private void drawSearch(Graphics g, Rectangle bound)
    {
        g.setColor(Color.red);
        g.drawOval(point.x-2, point.y-2, 4,4);                

        // Nothing to search: no finder yet or an empty data set
        if(nnfind == null || nnfind.ndata == 0) return;

        // Never ask for more neighbors than there are points; kn can be
        // stale after the data set was reset and refilled
        int wanted = Math.min(kn, nnfind.ndata);

        g.setColor(Color.green);
        Point pc = nnfind.FindFirstNN(point);
        g.drawOval(pc.x-2, pc.y-2, 4,4);                                
        for(int k = 1; k < wanted; k++)
        {
            pc = nnfind.FindNextNN();
            g.drawOval(pc.x-2, pc.y-2, 4,4);             
        }

        // Empty circle: centered on the query, through the last neighbor
        g.setColor(Color.gray);
        int fr = (int)Math.sqrt((double)nnfind.ComputeDistance(point, pc));
        g.drawOval(point.x-fr, point.y-fr, 2*fr, 2*fr);

        // Region scanned by the last search step
        g.setColor(Color.lightGray);
        int r = nnfind.getMaxRadius();
        if(r > 0) 
        {
            g.drawOval(point.x-r, point.y-r, 2*r, 2*r);
            if(nnfind.getProjectionAxe() == NNFinder.PROJX)
            {
                // Projection on X: the scanned band is bounded by vertical lines
                g.drawLine(point.x-r, 0, point.x-r, bound.height);
                g.drawLine(point.x+r, 0, point.x+r, bound.height);
            }
            else
            {
                // Projection on Y: the scanned band is bounded by horizontal lines
                g.drawLine(0, point.y-r, bound.width, point.y-r);
                g.drawLine(0, point.y+r, bound.width, point.y+r);
            }
        }

        // Brute-force reference; at least 1 because wanted >= 1 and ndata >= 1
        int brute = 0;
        for(int k = 0; k < wanted; k++) brute += (nnfind.ndata - k);
        MessageText.setText("Number of compare: "+nnfind.getCompare()+"/"+brute+" ("
                + nnfind.getCompare()*100/brute+"%)");
    }
}

class NNFinder
{
    // Projection axes defines
    public static final int PROJX = 1;
    public static final int PROJY = 2;
    
    int n;          // This is the number of point expected to be tested
    int projaxe;  // Current projection axe
    int compare; // Number of compare made to find the nearest
    int maxradius;  // Maximum radius of search in the array
    int ndata;      // Number of data in the arrays
    Point data[];   // Point array
    Point xindex[];   // Index for coordinates X (x is the index, y is the value)
    Point yindex[];   // Index for coordinates Y (x is the index, y is the value)    
    boolean flags[];    // Array of flags that indicate the validity of point (true or false)
    
    Point cindex[];     // Pointer to current index
    int cvalue, cind;   // Current values for the search
    Point cp;           // Current point being tested
    
    // Constructor of NNFinder
    public NNFinder(Vector points)
    {
        n = 10; // We expect to test 10 points... (this is arbitrary)
        // Copy the point vector into an array
        ndata = points.size();
        // We create the arrays
        data = new Point[ndata];
        xindex = new Point[ndata];
        yindex = new Point[ndata];
        flags = new boolean[ndata];
        // We initialise the data
        for(int i=0; i < ndata; i++) 
        {   
            data[i] = (Point)points.elementAt(i);
            // Create a new point which will have
            // field x : the index of the original point in the array (data)
            // field y : the value of the projected point on the axe
            xindex[i] = new Point(i, data[i].x);
            yindex[i] = new Point(i, data[i].y);
            flags[i] = true;    // The point is valid
        }
        // Sort the index for axe X
        cindex = xindex;
        BubbleSort();
        // Sort the index for axe Y
        cindex = yindex;
        BubbleSort();
        compare = 0;
    }

    
    // Simple bubble sort. Uses the index
    // array to determine the values of the
    // element of the array
    public void BubbleSort()
    {
        Point ptmp;
	    for (int i = ndata-1; i >= 0; i--) 
        {
	        for (int j = 0; j<i; j++)
            {
		        if (cindex[j].y > cindex[j+1].y) 
                {   // Swap the points
                    ptmp = cindex[j];
                    cindex[j] = cindex[j+1];
                    cindex[j+1] = ptmp;
                }
    		}
	    }
	}
    
    // Returns the number of compare done for 
    // finding the nearest point
    public int getCompare() { return compare;}
    
    // Returns the maximum radius of search
    // in the arrays of points.
    public int getMaxRadius() {return maxradius;}

    // Returns the projection axe used to find
    // the nearest neighbor.
    public int getProjectionAxe() {return projaxe;}
    
    // Dichotomic search in an ordered array using
    // index
    private int DichoSearchIndex(int value)
    {
        int inf, sup, centre;
        inf = 0;
        sup = ndata-1;
        
        // Check for obious case
        if(value <= cindex[inf].y) return inf;
        else if(value >= cindex[sup].y) return sup;

        // Search until we have at least two elements
        while(sup > inf)
        {
            centre = (inf+sup)/2;
            if(cindex[centre].y == value) return centre;
            else if(cindex[centre].y < value) inf = centre+1;
            else sup = centre-1;
        }
        return inf;
    }
    
    // This function will reset the all the flags to true
    private void ResetFlags()
    {
        for(int i=0; i < ndata; i++) flags[i] = true;
    }
     
    public Point FindFirstNN(Point p)
    {
        int index1, index2, i,j;
        int sparse1, sparse2;
        int xdim, ydim;
        float s1,s2;
        // Compute the dimension of the pointset
        xdim = xindex[ndata-1].y - xindex[0].y;
        ydim = yindex[ndata-1].y - yindex[0].y;
        
        // Remember the point being tested
        cp = p;    
        
        // Reset the flags
        ResetFlags();   
        
        // Find the point on the projected axes
        cindex = xindex;
        index1 = DichoSearchIndex(p.x);
        cindex = yindex;
        index2 = DichoSearchIndex(p.y);
        
        // Mesure the sparsity of axe X
        i = index1 - n/2;
        if(i < 0) i = 0;
        j = index1 + n/2;
        if(j >= ndata) j = ndata-1;
        sparse1 = xindex[j].y - xindex[i].y;
        // Normalize sparsity
        s1 = (float)sparse1/(float)xdim;

        // Mesure the sparsity of axe Y
        i = index2 - n/2;
        if(i < 0) i = 0;
        j = index2 + n/2;
        if(j >= ndata) j = ndata-1;
        sparse2 = yindex[j].y - yindex[i].y;
        // Normalise sparsity
        s2 = (float)sparse2/(float)ydim;
        
        if(s1 > s2)
        {   // We take the x axe
            cindex = xindex;
            cvalue = p.x;
            cind = index1;
            projaxe = PROJX;
        }
        else
        {   // We take the y axe
            cindex = yindex;
            cvalue = p.y;
            cind = index2;
            projaxe = PROJY;            
        }
        
        // Init the number of compare
        compare = 0;
        maxradius = 0;
        return FindNextNN();
    }
    
    public Point FindNextNN()
    {
        float mindist,dist;
        int mini;
        int i, il, ir;
        
        // Find the radius of the circle
        // cind is already at the left of the test point
        il = cind;
        // Find the closest next valid point
        while(flags[cindex[il].x] == false && il > 0) il--;
        if(flags[cindex[il].x] == true) 
        {
            // Compute the distance
            mindist = ComputeDistance(cp, data[cindex[il].x]);
            compare++;
        }
        else
            // Put something big no chance of having a distance bigger than that
            mindist = 500*500;
        
        // Here, we must verify if it's not the end already
        if(cind < ndata-1)
        {
            ir = cind+1;
            while(flags[cindex[ir].x] == false && ir < ndata-1) ir++;
            if(flags[cindex[ir].x] == true)
            {
                dist = ComputeDistance(cp, data[cindex[ir].x]);
                compare++;
            }
            else dist = 500*500;

            if(mindist < dist) 
            {
                maxradius = (int)Math.sqrt((double)mindist);
                mini = il;
            }
            else
            {
                mini = ir;
                maxradius = (int)Math.sqrt((double)dist);
                mindist = dist;
            }
        }
        else
        {
            mini = il;
            ir = ndata-1;
            maxradius = (int)Math.sqrt((double)mindist);
        }
            

        // Search to the left of cind
        i = il-1;
        while(i > 0 && maxradius > Math.abs(cvalue-cindex[i].y))
        {
            if(flags[cindex[i].x] == true)
            {
                dist = ComputeDistance(cp, data[cindex[i].x]);
                compare++;
                if(dist < mindist)
                {
                    mindist = dist;
                    mini=i;
                }
            }
            // Go to the left
            i--;
        }
        
        // Search to the right of cind
        i = ir+1;
        while(i < ndata && maxradius > Math.abs(cindex[i].y - cvalue))
        {
            if(flags[cindex[i].x] == true)
            {
                dist = ComputeDistance(cp, data[cindex[i].x]);
                compare++;
                if(dist < mindist)
                {
                    mindist = dist;
                    mini=i;
                }
            }
            i++; // Go to the right
        }
        
        // Set the flag of the point found to false so that
        // we don't find it again for next search
        flags[cindex[mini].x] = false;
       
        // return the closest point
        return data[cindex[mini].x];
    }
    // A crude way to find the nearest neighbor
    public Point FindNearestNeighborCrude(Point p)
    {
        float mindist,dist;
        int mini = 0;
        // Init compare
        compare = 0;
        mindist = ComputeDistance(p, data[mini]);
        for(int i=0; i < ndata; i++)
        {
            dist = ComputeDistance(p, data[i]);
            compare++;
            if(dist < mindist)
            {
                mini = i;
                mindist = dist;
            }
        }
        return data[mini];
    }
    // Compute the SQUARED distance between to points
    public float ComputeDistance(Point p1, Point p2)
    {
        float dx,dy;
        dx = p2.x - p1.x;
        dy = p2.y - p1.y;
        return (dx*dx + dy*dy);
    }
}

