Sudoku solver

coderodde 24.06.12 11:13

Softa N x N -sudokujen ratkaisemiseksi.

 Tekstiversio  Arvo: 0 (0 ääntä)  Äänestä: +  -
//// File: toughone.txt; poista tämä kommenttirivi!
0 0 5 3 0 0 0 0 0
8 0 0 0 0 0 0 2 0
0 7 0 0 1 0 5 0 0
4 0 0 0 0 5 3 0 0
0 1 0 0 7 0 0 0 6
0 0 3 2 0 0 0 8 0
0 6 0 5 0 0 0 0 9
0 0 4 0 0 0 0 3 0
0 0 0 0 0 9 7 0 0

//// File: Solver.java
package com.mureakuha.tools.sudoku;

import java.awt.Point;
import java.util.Scanner;

/**
 * This class is responsible for solving the sudokus of size N x N, where N is the
 * second power of a natural number, i.e. 2^2 = 4, 3^2 = 9, 4^2 = 16, 25, etc.
 *
 * Two methods are provided: <code>solveIteratively()</code> solving a sudoku
 * specified with <code>setMatrix()</code> using an iterative approach, and
 * <code>solveRecursively()</code> doing the same using recursion. Both are
 * <i>backtracking</i> algorithms similar to that solving the "N queens" -problem.
 */

public class Solver {
   
    /**
     * An utility class implementing a LIFO-stack backed by a linear array.
     */

    static class IntStack {
        private int size;
        private int[] array;
       
        IntStack( int capacity ){
            this.array = new int[capacity];
        }
       
        int peek(){
            return array[size - 1];
        }
       
        int pop(){
            return array[size-- - 1];
        }
       
        void push( int i ){
            array[size++] = i;
        }
       
        boolean isEmpty(){
            return size == 0;
        }
       
        int size(){
            return size;
        }
       
        void clear(){
            size = 0;
        }
    }
   
    /**
     * Implements a very simple "hash table" suitable for keeping track of the
     * numbers in each row, column or minisquare.
     */

    static class SudokuSet {
        private boolean[] contains;
       
        SudokuSet( final int N ){
            this.contains = new boolean[N];
        }
       
        void add( int number ){
            contains[number - 1] = true;
        }
       
        boolean contains( int number ){
            return contains[number - 1];
        }
       
        void remove( int number ){
            contains[number - 1] = false;
        }
    }
   
    // An alias for documentation's sake.
    public static final int UNUSED = 0;
   
    private final int N;
    private final int M;
    private int[][] matrix;
    private int[][] solution;
    private SudokuSet[] colSets;
    private SudokuSet[] rowSets;
    private SudokuSet[][] squareSets;
    private IntStack xStack;
    private IntStack yStack;
   
    public Solver( final int N ){
        if( N < 1 )
            throw new IllegalArgumentException( "'N < 1'" );
       
        // Cannot form mini-squares?
        if( Math.floor( Math.sqrt( N ) ) != Math.ceil( Math.sqrt( N ) ) )
            throw new IllegalArgumentException(
                    N + " is not the second power of a natural number."
                    );
       
        this.M = (int) Math.sqrt( N );
        this.N = N;
        xStack = new IntStack( N * N );
        yStack = new IntStack( N * N );
    }
   
    /**
     * Sets each element from argument matrix not in the set { 1, 2, ..., N } to
     * UNUSED (0).
     */

    private void adjustMatrix(){
        for( int y = 0; y < matrix.length; y++ )
            for( int x = 0; x < matrix[y].length; x++ )
                if( matrix[y][x] < 1 || matrix[y][x] > N )
                    matrix[y][x] = UNUSED;
    }
   
    /**
     * Performs a routine check for the argument matrix.
     */

    private void checkMatrix(){
        if( matrix == null )
            throw new IllegalArgumentException( "'matrix' is null." );
       
        if( matrix.length != N )
            throw new IllegalArgumentException(
                    "'matrix' does not have exactly " + N + " rows."
                    );
       
        for( int[] row : matrix )
            if( row == null )
                throw new NullPointerException( "a row is null." );
            else if( row.length != N )
                throw new IllegalArgumentException(
                        "a row does not have exactly " + N + " elements."
                        );
    }
   
    /**
     * Converts the coordinate (x, y) to the (C) coordinates of a mini-square
     * holding the point (x, y), and stores (C) in <code>p</code>.
     *
     * @param x the X-coordinate of a matrix element.
     * @param y the Y-coordinate of a matrix element.
     * @param p the target <code>Point</code>.
     */

    private void convertToSquareCoordinates( int x, int y, Point p ){
        if( x < 0 || y < 0 || x >= N || y >= N )
            return;
       
        p.x = x / M;
        p.y = y / M;
    }
   
    /**
     * Retrieves the solution matrix, if any available.
     *
     * @return the solution matrix.
     */

    public int[][] getSolution(){
        return solution;
    }
   
    /**
     * Loads the sets, each representing the numbers present in the argument
     * matrix' column.
     */

    private void loadColumnSets(){
        SudokuSet[] colSets = new SudokuSet[N];
       
        for( int x = 0; x < N; x++ )
            colSets[x] = new SudokuSet( N );
       
        for( int y = 0; y < matrix.length; y++ )
            for( int x = 0; x < matrix[y].length; x++ )
                if( matrix[y][x] != UNUSED )
                {
                    if( colSets[x].contains( matrix[y][x] ) )
                        return;
                   
                    colSets[x].add( matrix[y][x] );
                }
       
        this.colSets = colSets;
    }
   
    /**
     * Loads the sets, each representing the numbers present in the argument
     * matrix' row.
     */

    private void loadRowSets(){
        SudokuSet[] rowSets = new SudokuSet[N];
       
        for( int y = 0; y < N; y++ )
            rowSets[y] = new SudokuSet( N );
       
        for( int y = 0; y < matrix.length; y++ )
            for( int x = 0; x < matrix[y].length; x++ )
                if( matrix[y][x] != UNUSED )
                {
                    if( rowSets[y].contains( matrix[y][x] ) )
                        return;
                   
                    rowSets[y].add( matrix[y][x] );
                }
       
        this.rowSets = rowSets;
    }
   
    /**
     * Loads the sets, each representing the numbers present in the argument
     * matrix' mini-square.
     */

    private void loadSquareSets(){
        SudokuSet[][] squareSets = new SudokuSet[M][M];
       
        for( int y = 0; y < M; y++ )
        {
            squareSets[y] = new SudokuSet[M];
           
            for( int x = 0; x < M; x++ )
                squareSets[y][x] = new SudokuSet( N );
        }
       
        Point p = new Point();
       
        for( int y = 0; y < matrix.length; y++ )
            for( int x = 0; x < matrix[y].length; x++ )
                if( matrix[y][x] != UNUSED )
                {
                    convertToSquareCoordinates( x, y, p );

                    if( squareSets[p.y][p.x].contains( matrix[y][x] ) )
                        return;

                    squareSets[p.y][p.x].add( matrix[y][x] );
                }
       
        this.squareSets = squareSets;
    }
   
    /**
     * Sets the argument matrix for this <code>Solver</code>.
     *
     * @param matrix the argument matrix.
     */

    public void setMatrix( int[][] matrix ){
        this.matrix = matrix;
    }
   
    /**
     * This method checks the validity of the argument matrix, and if valid,
     * solves it. If the given matrix does not have an unique solution, the
     * first one will be produced as specified by <i>lexicographic</i> order.
     *
     * @return this <code>Solver</code>.
     */

    public Solver solveIteratively(){ 
        checkMatrix();
        adjustMatrix();
        this.colSets = null;
        loadColumnSets();
       
        if( this.colSets == null )
            throw new IllegalStateException(
                    "duplicate numbers in a column."
                    );
       
        this.rowSets = null;
        loadRowSets();
       
        if( this.rowSets == null )
            throw new IllegalStateException(
                    "duplicate numbers in a row."
                    );
       
        this.squareSets = null;
        loadSquareSets();
       
        if( this.squareSets == null )
            throw new IllegalStateException(
                    "duplicate numbers in a square."
                    );
       
        xStack.clear();
        yStack.clear();
       
        // Load the initial coordinate (0, 0).
        xStack.push( 0 );
        yStack.push( 0 );
       
        solution = new int[N][N];
       
        // Set the first element at (0, 0).
        if( matrix[0][0] != UNUSED )
            solution[0][0] = matrix[0][0];
        else
            for( int i = 1; i <= N; i++ )
                if(
                    !colSets[0].contains( i ) &&
                    !rowSets[0].contains( i ) &&
                    !squareSets[0][0].contains( i )
                )
                {
                    solution[0][0] = i;
                    colSets[0].add( i );
                    rowSets[0].add( i );
                    squareSets[0][0].add( i );
                    break;
                }
       
        if( solution[0][0] == UNUSED )
        {
            // Should not happen. An exception is thrown prior to getting into
            // this state.
            solution = null;
            return this;
        }
       
        Point next = new Point();
        Point top = new Point();
        Point p = new Point();
       
        // Main loop simulating recursion.
        outer:
        while( xStack.size() < N * N )
        {
            top.x = xStack.peek();
            top.y = yStack.peek();
           
            if( top.x == N - 1 )
            {
                // "Carriage return." Proceed to the next line.
                next.x = 0;
                next.y = top.y + 1;
            }
            else
            {
                // No new line needed, just increment x-coordinate.
                next.x = top.x + 1;
                next.y = top.y;
            }
           
            if( matrix[next.y][next.x] != UNUSED )
            {
                // Just load a predefined number from the argument matrix.
                solution[next.y][next.x] = matrix[next.y][next.x];
                // "Recur" further.
                xStack.push( next.x );
                yStack.push( next.y );
                continue;
            }
           
            // Find a number to set at (next.x, next.y).
            for( int i = 1; i <= N; i++ )
                if(
                    !colSets[next.x].contains( i ) &&
                    !rowSets[next.y].contains( i )
                )
                {
                    convertToSquareCoordinates( next.x, next.y, p );
                   
                    if( !squareSets[p.y][p.x].contains( i ) )
                    {
                        // Found it! Add to data structures.
                        solution[next.y][next.x] = i;
                        xStack.push( next.x );
                        yStack.push( next.y );
                        colSets[next.x].add( i );
                        rowSets[next.y].add( i );
                        squareSets[p.y][p.x].add( i );
                        // "Recur" further.
                        continue outer;
                    }
                }
           
            // Backtracking loop.
            for(;;)
            {
                // Skip predefined elements at the stack top.
                while(
                    xStack.size() > 0 &&
                    stackTopPointsToPredefined()
                )
                {
                    xStack.pop();
                    yStack.pop();
                }
               
                if( xStack.isEmpty() )
                {
                    // Should not happen. An exception is thrown prior to
                    // getting into this state.
                    solution = null;
                    return this;
                }

                int x = xStack.peek();
                int y = yStack.peek();
                int old = solution[y][x];
                convertToSquareCoordinates( x, y, p );

                // Increment the number at the stack top starting from x + 1,
                // where x is the old value.
                for( int i = old + 1; i <= N; i++ )
                    if(
                        !colSets[x].contains( i ) &&
                        !rowSets[y].contains( i ) &&
                        !squareSets[p.y][p.x].contains( i )
                    )
                    {
                        // Found next number fitting in.
                        // Remove the old one.
                        colSets[x].remove( old );
                        rowSets[y].remove( old );
                        squareSets[p.y][p.x].remove( old );
                        // Add the new one.
                        colSets[x].add( i );
                        rowSets[y].add( i );
                        squareSets[p.y][p.x].add( i );
                        solution[y][x] = i;
                        // Continue normal "recursion".
                        continue outer;
                    }
               
                // Nothing fits in at the location pointed by coordinate stacks;
                // remove the top and, thus, backtrack further.
                xStack.pop();
                yStack.pop();
                colSets[x].remove( old );
                rowSets[y].remove( old );
                squareSets[p.y][p.x].remove( old );
            }
        }
       
        return this;
    }
   
    /**
     * This method checks the validity of the argument matrix, and if valid,
     * solves it. If the given matrix does not have an unique solution, the
     * first one will be produced as specified by <i>lexicographic</i> order.
     *
     * @return this <code>Solver</code>.
     */

    public Solver solveRecursively(){
        checkMatrix();
        adjustMatrix();
        this.colSets = null;
        loadColumnSets();
       
        if( this.colSets == null )
            throw new IllegalStateException(
                    "duplicate numbers in a column."
                    );
       
        this.rowSets = null;
        loadRowSets();
       
        if( this.rowSets == null )
            throw new IllegalStateException(
                    "duplicate numbers in a row."
                    );
       
        this.squareSets = null;
        loadSquareSets();
       
        if( this.squareSets == null )
            throw new IllegalStateException(
                    "duplicate numbers in a square."
                    );
       
        this.solution = new int[N][N];
       
        // Begin recursion.
        if( !solveRecursivelyImpl( 0, 0 ) )
            solution = null;
       
        return this;
    }
   
    /**
     * Implements a backtracking algorithm for solving sudokus.
     *
     * @param x the X-coordinate of an element to set next.
     * @param y the Y-coordinate of an element to set next.
     *
     * @return <code>true</code> if this call recurs to a solution,
     *  <code>false</code> otherwise.
     */

    private boolean solveRecursivelyImpl( int x, int y ){
        if( x == N )
        {
            // Begin next row.
            x = 0;
            y++;
        }
       
        // Reached the bottom of recursion? (Solution found?)
        if( y == N )
            return true;
       
        if( matrix[y][x] != UNUSED )
        {
            // Just load a predefined number from matrix to solution, and
            // proceed further.
            solution[y][x] = matrix[y][x];
            return solveRecursivelyImpl( x + 1, y );
        }
        else
        {
            // Find least number fitting in the current position (x, y).
            for( int i = 1; i <= N; i++ )
                if(
                    !colSets[x].contains( i ) &&
                    !rowSets[y].contains( i )
                )
                {
                    Point p = new Point();
                    convertToSquareCoordinates( x, y, p );
                   
                    if( !squareSets[p.y][p.x].contains( i ) )
                    {
                        // Got here? Then i fits at (x, y), so add it.
                        solution[y][x] = i;
                        colSets[x].add( i );
                        rowSets[y].add( i );
                        squareSets[p.y][p.x].add( i );
                       
                        if( solveRecursivelyImpl( x + 1, y ) )
                            // Solution found if got here, just return.
                            return true;
                        else
                        {
                            // No solution for this i; iterate further.
                            colSets[x].remove( i );
                            rowSets[y].remove( i );
                            squareSets[p.y][p.x].remove( i );
                        }
                    }
                }
           
            // No number fits at this (x, y), backtrack a little.
            return false;
        }
    }
   
    /**
     * Checks whether the current stack top points to a predefined element from
     * <code>matrix</code>.
     *
     * @return <code>true</code> if the stack top points to a predefined
     * element, <code>false</code> otherwise.
     */

    private boolean stackTopPointsToPredefined(){
        return matrix[yStack.peek()][xStack.peek()] != UNUSED;
    }
   
    /**
     * Checks whether the argument solution matrix is a valid sudoku solution.
     *
     * @param matrix the matrix to run the test against.
     *
     * @return <code>true</code> if <code>matrix</code> specifies a valid
     * sudoku solution, <code>false</code> otherwise.
     */

    public static boolean isValidSolutionMatrix( int[][] matrix ){
        if( matrix == null )
            return false;
       
        final int N = matrix.length;
       
        for( int[] row : matrix )
            if( row == null || row.length != N )
                return false;
       
        if( Math.ceil( Math.sqrt( N ) ) != Math.floor( Math.sqrt( N ) ) )
            return false;
       
        final int M = (int) Math.sqrt( N );
       
        SudokuSet[] colSets = new SudokuSet[N];
        SudokuSet[] rowSets = new SudokuSet[N];
        SudokuSet[][] squareSets = new SudokuSet[M][M];
       
        for( int i = 0; i < N; i++ )
        {
            colSets[i] = new SudokuSet( N );
            rowSets[i] = new SudokuSet( N );
        }
       
        for( int y = 0; y < M; y++ )
            for( int x = 0; x < M; x++ )
                squareSets[y][x] = new SudokuSet( N );
       
        Point p = new Point();
       
        for( int y = 0; y < N; y++ )
            for( int x = 0; x < N; x++ )
            {
                int current = matrix[y][x];
               
                if( current < 1 || current > N )
                    return false;
                else if(
                    colSets[x].contains( current ) ||
                    rowSets[y].contains( current )
                )
                    return false;
                else
                {
                    p.x = x / M;
                    p.y = y / M;
                   
                    if( squareSets[p.y][p.x].contains( current ) )
                        return false;
                   
                    colSets[x].add( current );
                    rowSets[y].add( current );
                    squareSets[p.y][p.x].add( current );
                }
            }
       
        return true;           
    }
   
    /**
     * Prints an integer matrix using a particular field size for each element.
     *
     * @param matrix the integer matrix.
     * @param fieldSize the field size (in characters).
     */

    public static void printMatrix( int[][] matrix, int fieldSize ){
        for( int y = 0; y < matrix.length; y++ )
        {
            for( int x = 0; x < matrix[y].length; x++ )
                System.out.printf( "%" + fieldSize + "d ", matrix[y][x] );
           
            System.out.println();
        }
    }
   
    /**
     * The entry point of this software.
     *
     * @param args the command line arguments.
     */

    public static void main( String... args ){
        int N = 9;
        boolean recursive = true;
       
        if( args.length > 0 )
        {
            try {
                N = Integer.parseInt( args.length > 1 ? args[1] : args[0] );
            }
            catch( NumberFormatException e ){
                System.out.println(
                        "usage: java -jar THIS.jar [[-i] N]\n" +
                        "  where N is the problem size\n" +
                        "  (the second power of a natural number),\n" +
                        "  -i as iterative algorithm, omit for recursive."
                        );
               
                System.exit( -1 );
            }
           
            if( N < 1 )
            {
                N = 4;
               
                System.out.println( "N corrected." );
            }
           
            if( args.length > 1 && args[0].equals( "-i" ) )
                recursive = false;
        }
       
        System.out.println( "Problem size: " + N );
        System.out.println(
                "Method: " + (recursive ? "recursive." : "iterative.")
                );
       
        Solver solver = new Solver( N );
        int[][] matrix = new int[N][N];
        Scanner scanner = new Scanner( System.in );
       
        // Read the matrix line by line from stdin until matrix filled or
        // cannot read an integer.
        outer:
        for( int y = 0; y < N; y++ )
            for( int x = 0; x < N; x++ )
                if( scanner.hasNextInt() )
                    matrix[y][x] = scanner.nextInt();
                else
                    break outer;
       
        scanner.close();
        solver.setMatrix( matrix );
        long ta = System.currentTimeMillis();
       
        if( recursive )
            solver.solveRecursively();
        else
            solver.solveIteratively();
       
        long tb = System.currentTimeMillis();
        int[][] solution = solver.getSolution();
       
        if( solution != null )
        {
            int fieldSize;
           
            if( N < 10 )
                fieldSize = 1;
            else if(
                Math.ceil( Math.log10( N ) ) != Math.floor( Math.log10( N ) )
            )
                fieldSize = (int) Math.ceil( Math.log10( N ) );
            else
                fieldSize = (int)(Math.ceil( Math.log10( N ) )) + 1;
           
            printMatrix( solution, fieldSize );
            System.out.println(
                    "Valid solution: " + isValidSolutionMatrix( solution )
                    );
        }
        else
            System.out.println( "No solutions." );
       
        System.out.println( "Time: " + (tb - ta) + " ms." );
    }
}
 

coderodde 11:18 24.6.12 
Jos saat käännettyä JAR - tiedoston, niin kannattaa kokeilla seuraavaa:
java -jar xxx.jar 9 < /path/to/toughone.txt
tai
java -jar xxx.jar -i 9 < /path/to/toughone.txt

Jos sinulla on vain käännetyt .class-tiedostot, homma menee jotakuinkin näin:
java com.mureakuha.tools.sudoku.Solver -i 9 < /path/to/toughone.txt
kansiosta, joka sisältää "com"-kansion.
coderodde 11:21 24.6.12 
Olisi riittänyt vain yksi kahdesta algoritmista, mutta itseäni kiinnosti, joko iteratiivisen ja rekursiivisen toteutuksen välillä olisi mitään suoriutuskykyeroa.
editoitu: 13:25 24.6.12
coderodde 11:23 24.6.12 
Parametrin N on oltava luonnollisen luvun toka potenssi. (Muuten neliömuotoisia miniruudukoita ei pysty muodostamaan, jolloin heitellään poikkeuksia.)