隨筆 - 22  文章 - 467  trackbacks - 0
          <2025年6月>
          25262728293031
          1234567
          891011121314
          15161718192021
          22232425262728
          293012345

          常用鏈接

          留言簿(25)

          隨筆分類(74)

          文章分類(1)

          收藏夾(277)

          在線工具

          在線資料

          最新隨筆

          搜索

          •  

          積分與排名

          • 積分 - 217706
          • 排名 - 261

          最新評論

          閱讀排行榜

          評論排行榜

          矩陣乘法的多線程實現:

          /**    
          * @Title: MultiThreadMatrix.java 
          * @Package matrix 
          * @Description: 多線程計算矩陣乘法 
          @author Aloong 
          * @date 2010-10-28 下午09:45:56 
          @version V1.0 
          */
           

          package matrix;

          import java.util.Date;


          public class MultiThreadMatrix
          {
              
              
          static int[][] matrix1;
              
          static int[][] matrix2;
              
          static int[][] matrix3;
              
          static int m,n,k;
              
          static int index;
              
          static int threadCount;
              
          static long startTime;
              
              
          public static void main(String[] args) throws InterruptedException
              
          {
                  
          //矩陣a高度m=100寬度k=80,矩陣b高度k=80寬度n=50 ==> 矩陣c高度m=100寬度n=50
                  m = 1024;
                  n 
          = 1024;
                  k 
          = 1024;
                  matrix1 
          = new int[m][k];
                  matrix2 
          = new int[k][n];
                  matrix3 
          = new int[m][n];
                  
                  
          //隨機初始化矩陣a,b
                  fillRandom(matrix1);
                  fillRandom(matrix2);
                  startTime 
          = new Date().getTime();
                  
                  
          //輸出a,b
          //        printMatrix(matrix1);
          //        printMatrix(matrix2);
                  
                  
          //創建線程,數量 <= 4
                  for(int i=0; i<4; i++)
                  
          {
                      
          if(index < m)
                      
          {
                          Thread t 
          = new Thread(new MyThread());
                          t.start();
                      }
          else 
                      
          {
                          
          break;
                      }

                  }

                  
                  
          //等待結束后輸出
                  while(threadCount!=0)
                  
          {
                      Thread.sleep(
          20);
                  }

                  
          //        printMatrix(matrix3);
                  long finishTime = new Date().getTime();
                  System.out.println(
          "計算完成,用時"+(finishTime-startTime)+"毫秒");
              }

              
              
          static void printMatrix(int[][] x)
              
          {
                  
          for (int i=0; i<x.length; i++)
                  
          {
                      
          for(int j=0; j<x[i].length; j++)
                      
          {
                          System.out.print(x[i][j]
          +" ");
                      }

                      System.out.println(
          "");
                  }

                  System.out.println(
          "");
              }

              
              
          static void fillRandom(int[][] x)
              
          {
                  
          for (int i=0; i<x.length; i++)
                  
          {
                      
          for(int j=0; j<x[i].length; j++)
                      
          {
                          
          //每個元素設置為0到99的隨機自然數
                          x[i][j] = (int) (Math.random() * 100);
                      }

                  }

              }


              
          synchronized static int getTask()
              
          {
                  
          if(index < m)
                  
          {
                      
          return index++;
                  }

                  
          return -1;
              }


          }


          class MyThread implements Runnable
          {
              
          int task;
              @Override
              
          public void run()
              
          {
                  MultiThreadMatrix.threadCount
          ++;
                  
          while( (task = MultiThreadMatrix.getTask()) != -1 )
                  
          {
                      System.out.println(
          "進程: "+Thread.currentThread().getName()+"\t開始計算第 "+(task+1)+"");
                      
          for(int i=0; i<MultiThreadMatrix.n; i++)
                      
          {
                          
          for(int j=0; j<MultiThreadMatrix.k; j++)
                          
          {
                              MultiThreadMatrix.matrix3[task][i] 
          += MultiThreadMatrix.matrix1[task][j] * MultiThreadMatrix.matrix2[j][i];
                          }

                      }

                  }

                  MultiThreadMatrix.threadCount
          --;
              }

          }



          單線程:

          /**    
          * @Title: SingleThreadMatrix.java 
          * @Package matrix 
          * @Description: 單線程計算矩陣乘法 
          @author Aloong 
          * @date 2010-10-28 下午11:33:18 
          @version V1.0 
          */
           

          package matrix;

          import java.util.Date;


          public class SingleThreadMatrix
          {
              
          static int[][] matrix1;
              
          static int[][] matrix2;
              
          static int[][] matrix3;
              
          static int m,n,k;
              
          static long startTime;
              
              
          public static void main(String[] args)
              
          {
                  m 
          = 1024;
                  n 
          = 1024;
                  k 
          = 1024;
                  matrix1 
          = new int[m][k];
                  matrix2 
          = new int[k][n];
                  matrix3 
          = new int[m][n];
                  
                  fillRandom(matrix1);
                  fillRandom(matrix2);
                  startTime 
          = new Date().getTime();
                  
                  
          //輸出a,b
          //        printMatrix(matrix1);
          //        printMatrix(matrix2);
                  

                  
                  
          for(int task=0; task<m; task++)
                  
          {
                      System.out.println(
          "進程: "+Thread.currentThread().getName()+"\t開始計算第 "+(task+1)+"");
                      
          for(int i=0; i<n; i++)
                      
          {
                          
          for(int j=0; j<k; j++)
                          
          {
                              matrix3[task][i] 
          += matrix1[task][j] * matrix2[j][i];
                          }

                      }

                  }

                  
          //        printMatrix(matrix3);
                  long finishTime = new Date().getTime();
                  System.out.println(
          "計算完成,用時"+(finishTime-startTime)+"毫秒");
              }


              
          static void fillRandom(int[][] x)
              
          {
                  
          for (int i=0; i<x.length; i++)
                  
          {
                      
          for(int j=0; j<x[i].length; j++)
                      
          {
                          
          //每個元素設置為0到99的隨機自然數
                          x[i][j] = (int) (Math.random() * 100);
                      }

                  }

              }

          }


          修改m,n,k的值可以修改相乘矩陣的階數.

          結果對比,計算1024階矩陣的時候多線程用時約4.8秒,單線程用時16秒,
          單線程占用內存21M,多線程占用16M.
          本機是4核CPU,單線程的時候只有25%的CPU占用,使用4個子線程可以達到接近100%的CPU使用率.


          另外請教一個問題,是矩陣乘法的Strassen算法
          下面這個是來自網上的一段代碼,在我自己的電腦上,只要超過12階就會內存溢出
          不解是什么原因,設置jvm的內存不管多大也會崩潰在12階
          請高手幫忙解答....


          package matrix;

          import java.io.*;
          import java.util.*;

          class Matrix //定義矩陣結構    
          {
              
          public int[][] m = new int[32][32];
          }


          public class StrassenMatrix2
          {
              
          public int IfIsEven(int n)//判斷輸入矩陣階數是否為2^k    
              {
                  
          int a = 0, temp = n;
                  
          while (temp % 2 == 0)
                  
          {
                      
          if (temp % 2 == 0)
                          temp 
          /= 2;
                      
          else
                          a 
          = 1;
                  }

                  
          if (temp == 1)
                      a 
          = 0;
                  
          return a;
              }


              
          public void Divide(Matrix d, Matrix d11, Matrix d12, Matrix d21, Matrix d22, int n)//分解矩陣    
              {
                  
          int i, j;
                  
          for (i = 1; i <= n; i++)
                      
          for (j = 1; j <= n; j++)
                      
          {
                          d11.m[i][j] 
          = d.m[i][j];
                          d12.m[i][j] 
          = d.m[i][j + n];
                          d21.m[i][j] 
          = d.m[i + n][j];
                          d22.m[i][j] 
          = d.m[i + n][j + n];
                      }

              }


              
          public Matrix Merge(Matrix a11, Matrix a12, Matrix a21, Matrix a22, int n)//合并矩陣    
              {
                  
          int i, j;
                  Matrix a 
          = new Matrix();
                  
          for (i = 1; i <= n; i++)
                      
          for (j = 1; j <= n; j++)
                      
          {
                          a.m[i][j] 
          = a11.m[i][j];
                          a.m[i][j 
          + n] = a12.m[i][j];
                          a.m[i 
          + n][j] = a21.m[i][j];
                          a.m[i 
          + n][j + n] = a22.m[i][j];
                      }

                  
          return a;
              }


              
          public Matrix TwoMatrixMultiply(Matrix x, Matrix y) //階數為2的矩陣乘法    
              {
                  
          int m1, m2, m3, m4, m5, m6, m7;
                  Matrix z 
          = new Matrix();

                  m1 
          = (y.m[1][2- y.m[2][2]) * x.m[1][1];
                  m2 
          = y.m[2][2* (x.m[1][1+ x.m[1][2]);
                  m3 
          = (x.m[2][1+ x.m[2][2]) * y.m[1][1];
                  m4 
          = x.m[2][2* (y.m[2][1- y.m[1][1]);
                  m5 
          = (x.m[1][1+ x.m[2][2]) * (y.m[1][1+ y.m[2][2]);
                  m6 
          = (x.m[1][2- x.m[2][2]) * (y.m[2][1+ y.m[2][2]);
                  m7 
          = (x.m[1][1- x.m[2][1]) * (y.m[1][1+ y.m[1][2]);
                  z.m[
          1][1= m5 + m4 - m2 + m6;
                  z.m[
          1][2= m1 + m2;
                  z.m[
          2][1= m3 + m4;
                  z.m[
          2][2= m5 + m1 - m3 - m7;
                  
          return z;
              }


              
          public Matrix MatrixPlus(Matrix f, Matrix g, int n) //矩陣加法    
              {
                  
          int i, j;
                  Matrix h 
          = new Matrix();
                  
          for (i = 1; i <= n; i++)
                      
          for (j = 1; j <= n; j++)
                          h.m[i][j] 
          = f.m[i][j] + g.m[i][j];
                  
          return h;
              }


              
          public Matrix MatrixMinus(Matrix f, Matrix g, int n) //矩陣減法方法    
              {
                  
          int i, j;
                  Matrix h 
          = new Matrix();
                  
          for (i = 1; i <= n; i++)
                      
          for (j = 1; j <= n; j++)
                          h.m[i][j] 
          = f.m[i][j] - g.m[i][j];
                  
          return h;
              }


              
          public Matrix MatrixMultiply(Matrix a, Matrix b, int n) //矩陣乘法方法    
              {
                  
          int k;
                  Matrix a11, a12, a21, a22;
                  a11 
          = new Matrix();
                  a12 
          = new Matrix();
                  a21 
          = new Matrix();
                  a22 
          = new Matrix();
                  Matrix b11, b12, b21, b22;
                  b11 
          = new Matrix();
                  b12 
          = new Matrix();
                  b21 
          = new Matrix();
                  b22 
          = new Matrix();
                  Matrix c11, c12, c21, c22, c;
                  c11 
          = new Matrix();
                  c12 
          = new Matrix();
                  c21 
          = new Matrix();
                  c22 
          = new Matrix();
                  c 
          = new Matrix();
                  Matrix m1, m2, m3, m4, m5, m6, m7;
                  k 
          = n;
                  
          if (k == 2)
                  
          {
                      c 
          = TwoMatrixMultiply(a, b);
                      
          return c;
                  }
           else
                  
          {
                      k 
          = n / 2;
                      Divide(a, a11, a12, a21, a22, k); 
          //拆分A、B、C矩陣    
                      Divide(b, b11, b12, b21, b22, k);
                      Divide(c, c11, c12, c21, c22, k);

                      m1 
          = MatrixMultiply(a11, MatrixMinus(b12, b22, k), k);
                      m2 
          = MatrixMultiply(MatrixPlus(a11, a12, k), b22, k);
                      m3 
          = MatrixMultiply(MatrixPlus(a21, a22, k), b11, k);
                      m4 
          = MatrixMultiply(a22, MatrixMinus(b21, b11, k), k);
                      m5 
          = MatrixMultiply(MatrixPlus(a11, a22, k),
                              MatrixPlus(b11, b22, k), k);
                      m6 
          = MatrixMultiply(MatrixMinus(a12, a22, k),
                              MatrixPlus(b21, b22, k), k);
                      m7 
          = MatrixMultiply(MatrixMinus(a11, a21, k),
                              MatrixPlus(b11, b12, k), k);
                      c11 
          = MatrixPlus(MatrixMinus(MatrixPlus(m5, m4, k), m2, k), m6, k);
                      c12 
          = MatrixPlus(m1, m2, k);
                      c21 
          = MatrixPlus(m3, m4, k);
                      c22 
          = MatrixMinus(MatrixMinus(MatrixPlus(m5, m1, k), m3, k), m7, k);

                      c 
          = Merge(c11, c12, c21, c22, k); //合并C矩陣    
                      return c;
                  }

              }


              
          public Matrix GetMatrix(Matrix X, int n)
              
          {
                  
          int i, j;
                  X 
          = new Matrix();
                  
          for (i = 1; i <= n; i++)
                      
          for (j = 1; j <= n; j++)
                          X.m[i][j] 
          = (int) (Math.random() * 10);
                  
          for (i = 1; i <= n; i++)
                  
          {
                      
          for (j = 1; j <= n; j++)
                          System.out.print(X.m[i][j] 
          + " ");
                      System.out.println();
                  }

                  
          return X;
              }


              
          public Matrix UsualMatrixMultiply(Matrix A, Matrix B, Matrix C, int n)
              
          {
                  
          int i, j, t, k;
                  
          for (i = 1; i <= n; i++)
                      
          for (j = 1; j <= n; j++)
                      
          {
                          
          for (k = 1, t = 0; k <= n; k++)
                              t 
          += A.m[i][k] * B.m[k][j];
                          C.m[i][j] 
          = t;
                      }

                  
          return C;
              }


              
          public static void main(String[] args) throws IOException
              
          {
                  StrassenMatrix2 instance 
          = new StrassenMatrix2();
                  
          int i, j, n;
          //        Matrix A, B, C, D;
                  Matrix A, B, C;
                  A 
          = new Matrix();
                  B 
          = new Matrix();
                  C 
          = new Matrix();
          //        D = new matrix();
                  Scanner in = new Scanner(System.in);
                  System.out.print(
          "輸入矩陣的階數: ");
                  n 
          = in.nextInt();
                  
          if (instance.IfIsEven(n) == 0)
                  
          {
                      System.out.println(
          "矩陣A:");
                      A 
          = instance.GetMatrix(A, n);
                      System.out.println(
          "矩陣B:");
                      B 
          = instance.GetMatrix(B, n);
                      
          if (n == 1)
                          C.m[
          1][1= A.m[1][1* B.m[1][1]; //矩陣階數為1時的特殊處理     
                      else
                      
          {
                          
          long startTime = new Date().getTime();
                          C 
          = instance.MatrixMultiply(A, B, n);
                          
          long finishTime = new Date().getTime();
                          System.out.println(
          "計算完成,用時"+(finishTime-startTime)+"毫秒");
                      }

                      System.out.println(
          "Strassen矩陣C為:");
                      
          for (i = 1; i <= n; i++)
                      
          {
                          
          for (j = 1; j <= n; j++)
                              System.out.print(C.m[i][j] 
          + " ");
                          System.out.println();
                      }

                      
          /*            D = instance.UsualMatrixMultiply(A, B, D, n);
                      System.out.println("普通乘法矩陣D為:");
                      for (i = 1; i <= n; i++)
                      {
                          for (j = 1; j <= n; j++)
                              System.out.print(D.m[i][j] + " ");
                          System.out.println();
                      }
          */

                  }
           else
                      System.out.println(
          "輸入的階數不是2的N次方");
              }

          }
           
          posted on 2010-10-29 16:23 ApolloDeng 閱讀(4128) 評論(0)  編輯  收藏 所屬分類: 提問分享Java

          只有注冊用戶登錄后才能發表評論。


          網站導航:
           
          51La
          主站蜘蛛池模板: 巴里| 花莲县| 孝感市| 潮州市| 桂林市| 大邑县| 秦皇岛市| 松江区| 深水埗区| 新巴尔虎左旗| 辽宁省| 红河县| 柏乡县| 宿松县| 衡阳市| 手游| 招远市| 佛学| 兴城市| 常宁市| 平山县| 额尔古纳市| 宜都市| 曲松县| 洛扎县| 昌平区| 达日县| 镇安县| 富蕴县| 库伦旗| 杂多县| 麻江县| 衡阳市| 新乐市| 师宗县| 临湘市| 阿克苏市| 建宁县| 喜德县| 闽侯县| 临夏市|