Sort_2_Partition_划分算法

什么是 Partition 算法 ?

1.Partition 算法实现的是将一个数组一分为二的一个 [分割] 算法 或者叫 [划分] 算法; 首先选择数组中的一个元素, (可以随机找一个元素, 或者简单点指定第 1 个元素), 搞一次划分, 划分完的效果我们是想把比它小的所有元素都放在它左边, 比它大的元素都放在它右边; 通常我们定义 partition 算法的返回值是分割之后的枢轴在数组中的 index
2.特殊情况下, 如果作为分割元素的元素是存在重复的, 那么返回什么位置 ? 有两种均为正确的设定
(i). 可以返回正确分割结果的第 1 个位置, 也就是最左侧位置
(ii). 可以返回正确分割结果的最后 1 个位置, 也就是连续重复子数组最后 1 个位置
3.上述两种返回结果略有区别, 对应的实现代码在比较和交换的细节上也是略有区别, 但上述细节区别不是这个问题的关键

单次扫描 partition 算法实现

我们想下这个算法怎么实现呢 ?
1.先指定 1 个用来划分的元素, 比如数组里面第 1 个元素
2.然后扫一遍数组, 把每个扫到比枢轴 <= 的元素换到最靠左侧的区域位置
3.所以要维护下一个用于交换的位置, 比如设定这个位置为 pos, 扫到对于任何 <= 数轴的元素, pos 指针都往后移动 (指向下一个可以交换的位置)
4.最终将枢轴位置元素交换到 pos 位置, 并返回 pos

/*
单次扫描 Partition 算法
思想是扫描一遍数组, 把每个扫到比枢轴 <= 的元素换到最靠左侧的区域位置, 所以要维护下一个用于交换的位置
枢轴遇到重复元素, 返回的位置是最后 1 个允许出现的位置
函数输入区间设定为 [l, r] 全闭区间
*/
int partition(vector<int>& nums, int l, int r) {
  // 设定第 1 个元素为分割元素, 分割的元素通常命名为枢轴 pivot
  int pivot = nums[l];
  // 设定一个 pos 指针, pos指针 维护指向下一个可以放置 <= pivot 元素的位置, 找到了就可以执行交换
  int pos = l;
  // 扫描所有位置, 扫到对于任何 <= 数轴的元素, pos 指针都往后移动 (指向下一个可以交换的位置), 
  // 通过交换 pos 和 i 元素可以实现
  for (int i = l+1; i <= r; ++i) {
    // 当找到一个比 pivot 小或者相等的元素, 直接交换
    if (nums[i] <= pivot) {
      ++pos;
      swap(nums[pos], nums[i]);
    }
  }
  // 扫描完之后, 所有的 <= pivot 的元素都被换到了最前面
  // 且 pos 指向最后一个 <= pivot 的位置, 其实也就是取等号的情况, 得到 pivot 的位置
  // 然后我们把元素
  swap(nums[l], nums[pos]);
  return pos;
}

对于 [5,9,2,1,4,7,5,8,3,6] 这个数组, 我们推演下单次扫描 partition 算法的完整划分过程

pivot = 5              pos = 0
i = 1, cur = 9, n <= 5 pos = 0             [5,9,2,1,4,7,5,8,3,6] 
i = 2, cur = 2, y <= 5 pos = 1  switch 9-2 [5,2,9,1,4,7,5,8,3,6]
i = 3, cur = 1, y <= 5 pos = 2  switch 9-1 [5,2,1,9,4,7,5,8,3,6]
i = 4, cur = 4, n <= 5 pos = 3  switch 9-4 [5,2,1,4,9,7,5,8,3,6]
i = 5, cur = 7, n <= 5 pos = 3             [5,2,1,4,9,7,5,8,3,6]
i = 6, cur = 5, y <= 5 pos = 4  switch 9-5 [5,2,1,4,5,7,9,8,3,6] 
i = 7, cur = 8, n <= 5 pos = 4             [5,2,1,4,5,7,9,8,3,6] 
i = 8, cur = 3, y <= 5 pos = 5  switch 7-3 [5,2,1,4,5,3,9,8,7,6]  
i = 9, cur = 6, n <= 5 pos = 5             [5,2,1,4,5,3,9,8,7,6]
swap(nums[l], nums[pos]) final_pos=5       [3,2,1,4,5,5,9,8,7,6]
测试正确性并打印过程代码

#include 
#include 
using namespace std;
int partition(vector& nums, int l, int r) {
  int pivot = nums[l];
  int pos = l;
  for (int i = l+1; i <= r; ++i) {
    if (nums[i] <= pivot) {
      ++pos;
      swap(nums[pos], nums[i]);
    }
    std::cout << "i=" << i << ", pos=" << pos << " [";
    for (int i = 0; i <= r; ++i) {
      std::cout << nums[i] << " ";
    }
    std::cout << "]" << std::endl;
  }
  swap(nums[l], nums[pos]);
  cout << "     final:[";
  for (int i = 0; i <= r; ++i) {
    std::cout << nums[i] << " ";
  }
  std::cout << "]";
  return pos;
}
int main() {
  vector c{5,9,2,1,4,7,5,8,3,6} ;
  int r = partition(c, 0, c.size()-1);
  std::cout <<  "result_pos=" << r ;
  return 0;
}

每次交换的时候, 可以对一个同位置交换判断, 少交换一次

int partition(vector<int>& nums, int l, int r) {
  int pivot = nums[l];
  int pos = l;
  for (int i = l+1; i <= r; ++i) {
    if (nums[i] <= pivot) {
      ++pos;
      // 位置相同的时候, 不发生交换
      if (i != pos) {
        swap(nums[pos], nums[i]);
      }
    }
  }
  swap(nums[l], nums[pos]);
  return pos;
}

如果我们将比较条件由 nums[i] <= pivot 转化成 nums[i] < pivot, 也就是设定只有在严格小的时候才交换, 在这样的设定下, 如果有重复的枢轴元素, 结果返回的位置是元素最左侧合法的位置; 我们推演下单次扫描 partition 算法的完整划分过程

i=1, pos=0 [5 9 2 1 4 7 5 8 3 6 ]
i=2, pos=1 [5 2 9 1 4 7 5 8 3 6 ]
i=3, pos=2 [5 2 1 9 4 7 5 8 3 6 ]
i=4, pos=3 [5 2 1 4 9 7 5 8 3 6 ]
i=5, pos=3 [5 2 1 4 9 7 5 8 3 6 ]
i=6, pos=3 [5 2 1 4 9 7 5 8 3 6 ]
i=7, pos=3 [5 2 1 4 9 7 5 8 3 6 ]
i=8, pos=4 [5 2 1 4 3 7 5 8 9 6 ]
i=9, pos=4 [5 2 1 4 3 7 5 8 9 6 ]
     final:[3 2 1 4 5 7 5 8 9 6 ], result_pos=4

更一般地, 我们选择数组中随机位置上的元素作为枢轴; 相对上述选择第 1 个元素为枢轴的方式, 改动方式为: 选择完随机元素, 并换到第 1 个位置上, 我们就重复朴素的单次扫描算法进行划分

回顾下 cpp 中
要取 [a,b) 的随机整数, 用 (rand() % (b-a)) + a;
要取 [a,b] 的随机整数, 用 (rand() % (b-a+1)) + a;
要取 (a,b] 的随机整数, 用 (rand() % (b-a)) + a + 1;

得到单次扫描 partition 算法

int partition(vector<int>& nums, int l, int r) {
  int rIdx = (rand() % (r - l + 1)) + l;
  int pivot = nums[rIdx];
  swap(nums[rIdx], nums[l]);
  int pos = l;
  for (int i = l+1; i <= r; ++i) {
    if (nums[i] <= pivot) {
      ++pos;
      if (i != pos) {
        swap(nums[pos], nums[i]);
      }
    }
  }
  swap(nums[l], nums[pos]);
  return pos;
}
测试正确性并打印过程代码

#include 
#include 
#include 
#include 
using namespace std;
int partition(vector& nums, int l, int r) {
  int rIdx = (rand() % (r - l + 1)) + l;
  int pivot = nums[rIdx];
  std::cout << "rIdx=" << rIdx << ", pviot=" << pivot << std::endl;
  swap(nums[rIdx], nums[l]);
  int pos = l;
  for (int i = l+1; i <= r; ++i) {
    if (nums[i] <= pivot) {
      ++pos;
      if (i != pos) {
        swap(nums[pos], nums[i]);
      }
    }
    std::cout << "i=" << i << ", pos=" << pos << " [";
    for (int i = 0; i <= r; ++i) {
      std::cout << nums[i] << " ";
    }
    std::cout << "]" << std::endl;
  }
  swap(nums[l], nums[pos]);
  cout << "     final:[";
  for (int i = 0; i <= r; ++i) {
    std::cout << nums[i] << " ";
  }
  std::cout << "]";
  return pos;
}

int main() {
  srand(time(nullptr));
  vector c{5,9,2,1,4,7,5,8,3,6} ;
  int r = partition(c, 0, c.size()-1);
  std::cout << "result_pos=" << r ;
  return 0;
}

双指针扫描 Partition 算法-交换实现

1.单次扫描实现的 Partition, 相当于在扫描的过程中把更小的元素全部移动在最前面; 扫描的过程中, 如果是遇到了比 pivot 更大的元素, 这个比 pviot 更大的元素可能会发生多次交换
2.一个优化的方法是: 我们引入双指针的思想, 从左右两边各自设置一个指针同时扫描, 左边的指针扫描目标是比 pivot 严格更大的, 右边的指针扫描目标是比 pivot 严格更小的, 二者用交换的方式实现 [对撞] 扫描, 同时安排 2 个位置的效果
3.思考下一个边界条件, while () 括号里面写的 l <= r ? 还是 l < r ? 也就是双指针的碰撞是否允许发生?
取等号的时候, l == r, 设定了两个指针至多相遇 1 次, 这种相遇我们称之为 [对撞]; 我们可以设定这种相遇发生, 并继续让两个指针继续移动
4.交换整个过程结束之后, 我们要返回的中分的位置是哪个 ? 或者问为什么返回值是 r 而不是 l ?
相遇之后, l 继续向右走 1 步, 同时也是至多走 1 步, 走完之后此时关系为 l > r; r 此时再也没有任何移动的机会, 此时 r 正好停在数组内我们想要的位置, l 可能处于超出数组合法 index 的位置

int partition(vector<int>& nums, int left, int right) {
  int randIdx = left + rand() % (right - left + 1);
  swap(nums[randIdx], nums[left]);
  int pivot = nums[left];
  int l = left;
  int r = right;
  while (l <= r) {
    // 从左往右, 找到第 1 个严格 > pivot 的位置
    while (l <= r && nums[l] <= pivot) {
      ++l;
    }
    // 从右往左, 找到第 1 个严格 < pivot 的位置
    while (l <= r && nums[r] >= pivot) {
      --r;
    }
    if (l <= r) {
      // 交换值
      swap(nums[l], nums[r]);
      ++l;
      --r;
    }
  }
  swap(nums[left], nums[r]);
  return r;
}

我们测试并推演不同枢轴选择下的划分过程

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=0 pivotNum=5
l=1, r=8 after possible swap: [5 3 2 1 4 7 5 8 9 6 ] l=2, r=7
l=5, r=4 after possible swap: [5 3 2 1 4 7 5 8 9 6 ] l=5, r=4
final                         [4 3 2 1 5 7 5 8 9 6 ] result_pos=4

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=1 pivotNum=9
l=10, r=9 after possible swap:[9 5 2 1 4 7 5 8 3 6 ] l=10, r=9
final                         [6 5 2 1 4 7 5 8 3 9 ] result_pos=9

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=2 pivotNum=2
l=1, r=3 after possible swap: [2 1 5 9 4 7 5 8 3 6 ] l=2, r=2
l=2, r=1 after possible swap: [2 1 5 9 4 7 5 8 3 6 ] l=2, r=1
final                         [1 2 5 9 4 7 5 8 3 6 ] result_pos=1

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=3 pivotNum=1
l=1, r=0 after possible swap: [1 9 2 5 4 7 5 8 3 6 ] l=1, r=0
final                         [1 9 2 5 4 7 5 8 3 6 ] result_pos=0

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=4 pivotNum=4
l=1, r=8 after possible swap: [4 3 2 1 5 7 5 8 9 6 ] l=2, r=7
l=4, r=3 after possible swap: [4 3 2 1 5 7 5 8 9 6 ] l=4, r=3
final                         [1 3 2 4 5 7 5 8 9 6 ] result_pos=3

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=5 pivotNum=7
l=1, r=9 after possible swap: [7 6 2 1 4 5 5 8 3 9 ] l=2, r=8
l=7, r=8 after possible swap: [7 6 2 1 4 5 5 3 8 9 ] l=8, r=7
final                         [3 6 2 1 4 5 5 7 8 9 ] result_pos=7

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=6 pivotNum=5
l=1, r=8 after possible swap: [5 3 2 1 4 7 5 8 9 6 ] l=2, r=7
l=5, r=4 after possible swap: [5 3 2 1 4 7 5 8 9 6 ] l=5, r=4
final                         [4 3 2 1 5 7 5 8 9 6 ] result_pos=4

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=7 pivotNum=8
l=1, r=9 after possible swap: [8 6 2 1 4 7 5 5 3 9 ] l=2, r=8
l=9, r=8 after possible swap: [8 6 2 1 4 7 5 5 3 9 ] l=9, r=8
final                         [3 6 2 1 4 7 5 5 8 9 ] result_pos=8

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=8 pivotNum=3
l=1, r=3 after possible swap: [3 1 2 9 4 7 5 8 5 6 ] l=2, r=2
l=3, r=2 after possible swap: [3 1 2 9 4 7 5 8 5 6 ] l=3, r=2
final                         [2 1 3 9 4 7 5 8 5 6 ] result_pos=2

ori                           [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=9 pivotNum=6
l=1, r=9 after possible swap: [6 5 2 1 4 7 5 8 3 9 ] l=2, r=8
l=5, r=8 after possible swap: [6 5 2 1 4 3 5 8 7 9 ] l=6, r=7
l=7, r=6 after possible swap: [6 5 2 1 4 3 5 8 7 9 ] l=7, r=6
final                         [5 5 2 1 4 3 6 8 7 9 ] result_pos=6
测试正确性并打印过程代码

#include 
#include 
#include 
#include 
#include 
using namespace std;
int partition(vector& nums, int left, int right, int pivotIdx) {
  // int randIdx = left + rand() % (right - left + 1);
  int randIdx = pivotIdx;
  swap(nums[randIdx], nums[left]);
  int pivot = nums[left];
  std::cout << "pivotIdx="  << pivotIdx << " pivotNum=" << pivot << std::endl;
  int l = left;
  int r = right;
  while (l <= r) {
    while (l <= r && nums[l] <= pivot) {
      ++l;
    }
    while (l <= r && nums[r] >= pivot) {
      --r;
    }
    std::cout << "l=" << l << ", r=" << r;
    if (l <= r) {
      swap(nums[l], nums[r]);
      ++l;
      --r;
    }
    std::cout << " after possible swap: [";
    for (int i = 0; i < nums.size(); ++i) {
      std::cout << nums[i] << " ";
    }
    std::cout << "]";
    std::cout << " l=" << l << ", r=" << r << std::endl;
  }
  swap(nums[left], nums[r]);
  cout << "final                         [";
  for (int i = 0; i < nums.size(); ++i) {
    std::cout << nums[i] << " ";
  }
  std::cout << "]";
  return r;
}
int main() {
  srand(time(nullptr));
  vector c{5,9,2,1,4,7,5,8,3,6};
  for (int k = 0; k < c.size(); ++k) {
    c = vector{5,9,2,1,4,7,5,8,3,6};
    cout << "ori                           [";
    for (int i = 0; i < c.size(); ++i) {
      std::cout << c[i] << " ";
    }
    std::cout << "]" << std::endl;
    int pos = partition(c, 0, c.size()-1, k);
    std::cout << " result_pos=" << pos << std::endl << std::endl;
  }
  return 0;
}

双指针扫描 Partition 算法-覆盖实现

1.分析上述的双指针对撞交换执行过程, 我们发现从两头到中间对撞的过程只发生一次, 同时两边的指针都是不走回头路的一遍遍历, 因此我们可以用覆盖的操作去替代交换的操作
2.然而, 和采用交换操作实现不同, 在双指针扫描覆盖算法实现过程中, 有一个 [初始枢轴位置选择] 和 [覆盖发生顺序] 的匹配问题:
(i). 如果初始时刻枢轴选择在最左边, 那么需要首先从右向左覆盖, 然后再从左往右覆盖, 然后重复这样的操作流程
(ii). 如果初始时刻枢轴选择在最右边, 那么需要首先从左向右覆盖, 然后再从右往左覆盖, 然后重复这样的操作流程
总结来说, 采用覆盖的方法的特点: 必须保证发生过程是严格正确且有序的
3.为什么有强制如上的设定, 必须匹配呢 ? 我们后面会分析一下错误的执行过程为什么会错误, 比如初始时刻枢轴选择在最左边, 先从左往右覆盖, 然后再从右往左覆盖会有什么问题
4.当我们选择 l < r 的边界设定下, 最后 while 终止的条件是 l == r, 因此, 最后一次覆盖写入的是 l 或者 r 都可以, 且返回位置用 l 或者 r 都可以
5.下面我们结合上述分析, 看下实现代码

初始枢轴选择左右边的情况: 先从右往左扫描写入, 然后再从左往右扫描写入

int partition(vector<int>& nums, int left, int right) {
  int randIdx = left + rand() % (right - left + 1);
  swap(nums[randIdx], nums[left]);
  // 初始枢轴选最左边
  int pivot = nums[left];
  int l = left;
  int r = right;
  while (l < r) {
    while (l < r && nums[r] >= pivot) {
      --r;
    }
    nums[l] = nums[r];
    while (l < r && nums[l] <= pivot) {
      ++l;
    }
    nums[r] = nums[l];
  }
  nums[l] = pivot;
  return l;
}

初始枢轴选择最右边的情况: 先从左往右扫描写入, 然后再从右往左扫描写入

int partition(vector<int>& nums, int left, int right) {
  int randIdx = left + rand() % (right - left + 1);
  swap(nums[randIdx], nums[right]);
  // 初始枢轴选最左边
  int pivot = nums[right];
  int l = left;
  int r = right;
  while (l < r) {
    while (l < r && nums[l] <= pivot) {
      ++l;
    }
    nums[r] = nums[l];
    while (l < r && nums[r] >= pivot) {
      --r;
    }
    nums[l] = nums[r];

  }
  nums[l] = pivot;
  return l;
}

我们测试并推演不同枢轴选择下的划分过程

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=0 pivotNum=5
l=1, r=8 after overwrite  [3 9 2 1 4 7 5 8 9 6 ]
l=4, r=4 after overwrite  [3 4 2 1 4 7 5 8 9 6 ]
final                     [3 4 2 1 5 7 5 8 9 6 ] result_pos=4

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=1 pivotNum=9
l=9, r=9 after overwrite  [6 5 2 1 4 7 5 8 3 6 ]
final                     [6 5 2 1 4 7 5 8 3 9 ] result_pos=9

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=2 pivotNum=2
l=1, r=3 after overwrite  [1 9 5 9 4 7 5 8 3 6 ]
l=1, r=1 after overwrite  [1 9 5 9 4 7 5 8 3 6 ]
final                     [1 2 5 9 4 7 5 8 3 6 ] result_pos=1

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=3 pivotNum=1
l=0, r=0 after overwrite  [1 9 2 5 4 7 5 8 3 6 ]
final                     [1 9 2 5 4 7 5 8 3 6 ] result_pos=0

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=4 pivotNum=4
l=1, r=8 after overwrite  [3 9 2 1 5 7 5 8 9 6 ]
l=3, r=3 after overwrite  [3 1 2 1 5 7 5 8 9 6 ]
final                     [3 1 2 4 5 7 5 8 9 6 ] result_pos=3

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=5 pivotNum=7
l=1, r=9 after overwrite  [6 9 2 1 4 5 5 8 3 9 ]
l=7, r=8 after overwrite  [6 3 2 1 4 5 5 8 8 9 ]
l=7, r=7 after overwrite  [6 3 2 1 4 5 5 8 8 9 ]
final                     [6 3 2 1 4 5 5 7 8 9 ] result_pos=7

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=6 pivotNum=5
l=1, r=8 after overwrite  [3 9 2 1 4 7 5 8 9 6 ]
l=4, r=4 after overwrite  [3 4 2 1 4 7 5 8 9 6 ]
final                     [3 4 2 1 5 7 5 8 9 6 ] result_pos=4

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=7 pivotNum=8
l=1, r=9 after overwrite  [6 9 2 1 4 7 5 5 3 9 ]
l=8, r=8 after overwrite  [6 3 2 1 4 7 5 5 3 9 ]
final                     [6 3 2 1 4 7 5 5 8 9 ] result_pos=8

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=8 pivotNum=3
l=1, r=3 after overwrite  [1 9 2 9 4 7 5 8 5 6 ]
l=2, r=2 after overwrite  [1 2 2 9 4 7 5 8 5 6 ]
final                     [1 2 3 9 4 7 5 8 5 6 ] result_pos=2

ori                       [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=9 pivotNum=6
l=1, r=9 after overwrite  [5 9 2 1 4 7 5 8 3 9 ]
l=5, r=8 after overwrite  [5 3 2 1 4 7 5 8 7 9 ]
l=6, r=6 after overwrite  [5 3 2 1 4 5 5 8 7 9 ]
final                     [5 3 2 1 4 5 6 8 7 9 ] result_pos=6
测试正确性并打印过程代码

#include 
#include 
#include 
#include 
#include 
using namespace std;
int partition(vector& nums, int left, int right, int pivotIdx) {
  // int randIdx = left + rand() % (right - left + 1);
  int randIdx = pivotIdx;
  swap(nums[randIdx], nums[left]);
  int pivot = nums[left];
  std::cout << "pivotIdx="  << pivotIdx << " pivotNum=" << pivot << std::endl;
  int l = left;
  int r = right;
  while (l < r) {
    while (l < r && nums[r] >= pivot) {
      --r;
    }
    nums[l] = nums[r];
    while (l < r && nums[l] <= pivot) {
      ++l;
    }
    nums[r] = nums[l];
    std::cout << "l=" << l << ", r=" << r << " after overwrite  [";
    for (int i = 0; i < nums.size(); ++i) {
      std::cout << nums[i] << " ";
    }
    std::cout << "]" << std::endl;
  }
  nums[l] = pivot;
  cout << "final                     [";
  for (int i = 0; i < nums.size(); ++i) {
    std::cout << nums[i] << " ";
  }
  std::cout << "]";
  return r;
}
int main() {
  srand(time(nullptr));
  vector c{5,9,2,1,4,7,5,8,3,6};
  for (int k = 0; k < c.size(); ++k) {
    c = vector{5,9,2,1,4,7,5,8,3,6};
    cout << "ori                       [";
    for (int i = 0; i < c.size(); ++i) {
      std::cout << c[i] << " ";
    }
    std::cout << "]" << std::endl;
    int pos = partition(c, 0, c.size()-1, k);
    std::cout << " result_pos=" << pos << std::endl << std::endl;
  }
  return 0;
}

测试如果选择最左侧为枢轴, 且从左往右写入的这种错误的情况会发生什么?
我们拿枢轴 = 5 的情况可以分析下, 初始的时候 l = 0, r = 9; 从左往右扫描到元素 9 之后, l 指向元素 9, r 仍然是初始 idx 9, 这时候强行发生一次写入: nums[l] = nums[r] 此时的 nums[r] 还是初始值, 也就是说 r 还没有发生一次有意义的指向, 就开始写入, 这是错误的
保证初始枢轴的位置和写入顺序的匹配性, 本质是说, 每次让扫描后的指针有了具体的意义, 才执行写入:
(i). 当我们初始选择最左边为枢轴的时候, 第 1 次从右往左的扫描使得 r 有意义, 然后先执行写入 nums[l] = nums[r];
(ii). 当我们初始选择最右边为枢轴的时候, 第 1 次从左往右的扫描使得 l 有意义, 然后先执行写入 nums[r] = nums[l];

# 以下演示错误的执行过程发生了什么
ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=0 pivotNum=5
l=1, r=9 try   overwrite [5 9 2 1 4 7 5 8 3 9 ]
l=1, r=8 try   overwrite [5 3 2 1 4 7 5 8 3 9 ]
l=5, r=8 try   overwrite [5 3 2 1 4 7 5 8 7 9 ]
l=5, r=5 try   overwrite [5 3 2 1 4 7 5 8 7 9 ]
final                    [5 3 2 1 4 5 5 8 7 9 ] result_pos=5

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=1 pivotNum=9
l=9, r=9 try   overwrite [9 5 2 1 4 7 5 8 3 6 ]
l=9, r=9 try   overwrite [9 5 2 1 4 7 5 8 3 6 ]
final                    [9 5 2 1 4 7 5 8 3 9 ] result_pos=9

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=2 pivotNum=2
l=1, r=9 try   overwrite [2 9 5 1 4 7 5 8 3 9 ]
l=1, r=3 try   overwrite [2 1 5 1 4 7 5 8 3 9 ]
l=2, r=3 try   overwrite [2 1 5 5 4 7 5 8 3 9 ]
l=2, r=2 try   overwrite [2 1 5 5 4 7 5 8 3 9 ]
final                    [2 1 2 5 4 7 5 8 3 9 ] result_pos=2

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=3 pivotNum=1
l=1, r=9 try   overwrite [1 9 2 5 4 7 5 8 3 9 ]
l=1, r=1 try   overwrite [1 9 2 5 4 7 5 8 3 9 ]
final                    [1 1 2 5 4 7 5 8 3 9 ] result_pos=1

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=4 pivotNum=4
l=1, r=9 try   overwrite [4 9 2 1 5 7 5 8 3 9 ]
l=1, r=8 try   overwrite [4 3 2 1 5 7 5 8 3 9 ]
l=4, r=8 try   overwrite [4 3 2 1 5 7 5 8 5 9 ]
l=4, r=4 try   overwrite [4 3 2 1 5 7 5 8 5 9 ]
final                    [4 3 2 1 4 7 5 8 5 9 ] result_pos=4

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=5 pivotNum=7
l=1, r=9 try   overwrite [7 9 2 1 4 5 5 8 3 9 ]
l=1, r=8 try   overwrite [7 3 2 1 4 5 5 8 3 9 ]
l=7, r=8 try   overwrite [7 3 2 1 4 5 5 8 8 9 ]
l=7, r=7 try   overwrite [7 3 2 1 4 5 5 8 8 9 ]
final                    [7 3 2 1 4 5 5 7 8 9 ] result_pos=7

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=6 pivotNum=5
l=1, r=9 try   overwrite [5 9 2 1 4 7 5 8 3 9 ]
l=1, r=8 try   overwrite [5 3 2 1 4 7 5 8 3 9 ]
l=5, r=8 try   overwrite [5 3 2 1 4 7 5 8 7 9 ]
l=5, r=5 try   overwrite [5 3 2 1 4 7 5 8 7 9 ]
final                    [5 3 2 1 4 5 5 8 7 9 ] result_pos=5

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=7 pivotNum=8
l=1, r=9 try   overwrite [8 9 2 1 4 7 5 5 3 9 ]
l=1, r=8 try   overwrite [8 3 2 1 4 7 5 5 3 9 ]
l=8, r=8 try   overwrite [8 3 2 1 4 7 5 5 3 9 ]
l=8, r=8 try   overwrite [8 3 2 1 4 7 5 5 3 9 ]
final                    [8 3 2 1 4 7 5 5 8 9 ] result_pos=8

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=8 pivotNum=3
l=1, r=9 try   overwrite [3 9 2 1 4 7 5 8 5 9 ]
l=1, r=3 try   overwrite [3 1 2 1 4 7 5 8 5 9 ]
l=3, r=3 try   overwrite [3 1 2 1 4 7 5 8 5 9 ]
l=3, r=3 try   overwrite [3 1 2 1 4 7 5 8 5 9 ]
final                    [3 1 2 3 4 7 5 8 5 9 ] result_pos=3

ori                      [5 9 2 1 4 7 5 8 3 6 ]
pivotIdx=9 pivotNum=6
l=1, r=9 try   overwrite [6 9 2 1 4 7 5 8 3 9 ]
l=1, r=8 try   overwrite [6 3 2 1 4 7 5 8 3 9 ]
l=5, r=8 try   overwrite [6 3 2 1 4 7 5 8 7 9 ]
l=5, r=6 try   overwrite [6 3 2 1 4 5 5 8 7 9 ]
l=6, r=6 try   overwrite [6 3 2 1 4 5 5 8 7 9 ]
l=6, r=6 try   overwrite [6 3 2 1 4 5 5 8 7 9 ]
final                    [6 3 2 1 4 5 6 8 7 9 ] result_pos=6

测试错误逻辑代码 - 测试如果枢轴的设定位置和覆盖的顺序没有正确匹配

#include <iostream>
#include <vector>
#include <algorithm>
#include <ctime>
#include <cstdlib>
using namespace std;
int partition(vector<int>& nums, int left, int right, int pivotIdx) {
  // int randIdx = left + rand() % (right - left + 1);
  int randIdx = pivotIdx;
  swap(nums[randIdx], nums[left]);
  int pivot = nums[left];
  std::cout << "pivotIdx="  << pivotIdx << " pivotNum=" << pivot << std::endl;
  int l = left;
  int r = right;
  while (l < r) {
    while (l < r && nums[r] >= pivot) {
      --r;
    }
    nums[l] = nums[r];
    std::cout << "l=" << l << ", r=" << r << " after overwrite  [";
    for (int i = 0; i < nums.size(); ++i) {
      std::cout << nums[i] << " ";
    }
    std::cout << "]" << std::endl;
    while (l < r && nums[l] <= pivot) {
      ++l;
    }
    nums[r] = nums[l];
    std::cout << "l=" << l << ", r=" << r << " after overwrite  [";
    for (int i = 0; i < nums.size(); ++i) {
      std::cout << nums[i] << " ";
    }
    std::cout << "]" << std::endl;
  }
  nums[l] = pivot;
  cout << "final                     [";
  for (int i = 0; i < nums.size(); ++i) {
    std::cout << nums[i] << " ";
  }
  std::cout << "]";
  return l;
}
int main() {
  srand(time(nullptr));
  vector<int> c{5,9,2,1,4,7,5,8,3,6};
  for (int k = 0; k < c.size(); ++k) {
    c = vector<int>{5,9,2,1,4,7,5,8,3,6};
    cout << "ori                       [";
    for (int i = 0; i < c.size(); ++i) {
      std::cout << c[i] << " ";
    }
    std::cout << "]" << std::endl;
    int pos = partition(c, 0, c.size()-1, k);
    std::cout << " result_pos=" << pos << std::endl << std::endl;
  }
  return 0;
}

三向切分 Partition

1.上面的 partition 我们将数组分为了小于 target 和大于 target 两部分, 也称之为二分 partition; 有时候我们会期望将数组分为 小于 target, 等于 target, 大于 target 的三部分, 这就是三向切分 Partition
2.通常来说, 我们定义三向切分的 Partition 的返回值为两个分割点, 这两个分割点是对应和 target 相等的左端点位置和右端点位置
3.三向切分的 Partition 如何实现 ? 维护几个指针
(i). lt, 维护的下一个找到 < target 的元素之后应该换到的位置, 使得 [low, lt-1] 区间的元素都 < target
(ii). gt, 维护的下一个找到 > target 的元素之后应该换到的位置, 使得 [gt+1, high] 区间的元素都 > target
(iii). i 指针移动, 使得 [lt, gt] 区间的元素都 == target
4.在指针移动的时候, 仍然采用遍历并交换实现, 注意指针移动的时候有更新的细节
(i). nums[i] < target, swap(nums[i], nums[lt]), 完成交换后, i 指针 [持续] 往后移动, 因为 i 左边的元素始终是严格 <= target 的, lt 指针自增到下一个可以用于交换的位置等候; lt++, i++;
(ii). nums[i] > target, swap(nums[i], nums[lt]), 完成交换之后, gt向前移动到下一个发生交换的位置, gt—; i 位置不发生任何变化, 为什么 i 不变呢 ? 因为在 i 位置上, 我们找到一个比 target 更大的元素换过去了, 换回来的这个元素, 和 target 的大小关系还没有判断, 因此 i 不发生任何变化
(iii). nums[i] == target, 相等的情况下无需发生任何交换, 只需要考虑下一个元素, i++

vector<int> threeWayPartition(vector<int>& nums, int low, int high) {
  int pivot = nums[low];
  int lt = low;
  int gt = high;
  int i = low + 1;
  while (i <= gt) {
    if (nums[i] < pivot) {
      swap(nums[i++], nums[lt++]);
    } else if (nums[i] > pivot) {
      swap(nums[i], nums[gt--]);
    } else if (nums[i] == pivot) {
      ++i;
    }
  }
  return vector<int>{lt, gt};
}
测试正确性并打印过程代码

#include 
#include 
#include 
#include 
#include 
using namespace std;
vector threeWayPartition(vector& nums, int low, int high) {
  int pivot = nums[low];
  int lt = low;
  int gt = high;
  int i = low + 1;
  while (i <= gt) {
      std::cout << "i=" << i << ", lt=" << lt << ", gt=" << gt;
    if (nums[i] < pivot) {
      swap(nums[i++], nums[lt++]);
    } else if (nums[i] > pivot) {
      swap(nums[i], nums[gt--]);
    } else if (nums[i] == pivot) {
      ++i;
    }
      std::cout << " after swap: " << "i=" << i << ", lt=" << lt << ", gt=" << gt << " [";
    for (int i = 0; i < nums.size(); ++i) {
      std::cout << nums[i] << " ";
    }
    std::cout << "]" << std::endl;
  }
  return vector{lt, gt};
}
int main() {
  vector c{5,5,9,2,1,4,7,5,8,3,6};
    cout << "ori [";
    for (int i = 0; i < c.size(); ++i) {
        std::cout << c[i] << " ";
    }
    std::cout << "]" << std::endl;
    vector pos = threeWayPartition(c, 0, c.size()-1);
    cout << "aft [";
    for (int i = 0; i < c.size(); ++i) {
        std::cout << c[i] << " ";
    }
    std::cout << "]" << std::endl;
    std::cout << "lt_pos=" << pos[0] << " gt_pos=" << pos[1] << std::endl;
  return 0;
}

得到的结果

ori                                           [5 5 9 2 1 4 7 5 8 3 6 ]
i=1, lt=0, gt=10 after swap: i=2, lt=0, gt=10 [5 5 9 2 1 4 7 5 8 3 6 ]
i=2, lt=0, gt=10 after swap: i=2, lt=0, gt=9  [5 5 6 2 1 4 7 5 8 3 9 ]
i=2, lt=0, gt=9  after swap: i=2, lt=0, gt=8  [5 5 3 2 1 4 7 5 8 6 9 ] 下一次扫描到 3, 是一个比 5 小的数, 交换到 lt 的位置
i=2, lt=0, gt=8  after swap: i=3, lt=1, gt=8  [3 5 5 2 1 4 7 5 8 6 9 ] 下一次扫描到 2, 是一个比 5 小的数, 交换到 lt 的位置
i=3, lt=1, gt=8  after swap: i=4, lt=2, gt=8  [3 2 5 5 1 4 7 5 8 6 9 ]
i=4, lt=2, gt=8  after swap: i=5, lt=3, gt=8  [3 2 1 5 5 4 7 5 8 6 9 ]
i=5, lt=3, gt=8  after swap: i=6, lt=4, gt=8  [3 2 1 4 5 5 7 5 8 6 9 ]
i=6, lt=4, gt=8  after swap: i=6, lt=4, gt=7  [3 2 1 4 5 5 8 5 7 6 9 ]
i=6, lt=4, gt=7  after swap: i=6, lt=4, gt=6  [3 2 1 4 5 5 5 8 7 6 9 ]
i=6, lt=4, gt=6  after swap: i=7, lt=4, gt=6  [3 2 1 4 5 5 5 8 7 6 9 ]
aft                                           [3 2 1 4 5 5 5 8 7 6 9 ]
lt_pos=4 gt_pos=6

三方切分 Partition 思想解决荷兰国旗问题

75.颜色分类

给定一个包含红色、白色和蓝色、共 n 个元素的数组 nums ,原地对它们进行排序,使得相同颜色的元素相邻,并按照红色、白色、蓝色顺序排列。我们使用整数 0、 1 和 2 分别表示红色、白色和蓝色。 必须在不使用库内置的 sort 函数的情况下解决这个问题。

示例 1:输入:nums = [2,0,2,1,1,0] 输出:[0,0,1,1,2,2]

示例 2:输入:nums = [2,0,1] 输出:[0,1,2]

分析:
1.该问题为经典的 [荷兰国旗问题], 不使用排序去做
2.最基础的想法是计数排序, 扫一遍统计出0/1/2的个数, 然后再扫一遍按照元素数量去覆盖数组

class Solution {
 public:
  void sortColors(vector<int>& nums) {
    int len0 = 0;      
    int len1 = 0;      
    int len2 = 0;      
    for (auto& x : nums) {
      if (x == 0) {
        ++len0;
      }
      if (x == 1) {
        ++len1;
      }
      if (x == 2) {
        ++len2;
      }
    }
    int i = 0;
    while (i < nums.size()) {
      if (i < len0) {
        nums[i] = 0;
      } else if (i < len0 + len1) {
        nums[i] = 1;
      } else {
        nums[i] = 2;
      }
      ++i;
    }
  }
}

3.利用三向切分 partition 的思想去进行交换, 在三向切分 Partition里面是维护两个指针分别指向 < target 的位置和 > target 的位置, 在荷兰国旗问题里面类似是指向 < 1 和 > 1的位置

class Solution {
 public:
  void sortColors(vector<int>& nums) {
    int len = nums.size();
    int p0 = 0;       // 指向 0 的做开区间, 最终指向 1 的左端点
    int p2 = len-1;   // 指向 2 的左开区间, 最终指向 1 的右端点
    int i = 0;
    while (i <= p2) {
      if (nums[i] == 0) {
        swap(nums[i++], nums[p0++]);
      } else if (nums[i] == 2) {
        swap(nums[i], nums[p2--]);
      } else if (nums[i] == 1) {
        ++i;
      }
    }
    return;
  }
};

更优雅的版本: 一次扫描, 且用覆盖来代替交换
维护 p0 写入和 p1 写入点两个指针, 扫描数组, 对于 nums[i] 来说, 不管是几先暂定是2, 这个 2 后续会因为 p0 和 p1 可能发生的写入而覆盖掉, 然后根据值 num 再决定 p0 和 p1 的相应写入和移动过程
(i). num == 0 或者 num == 1 的情况下, p1 指针写入并移动
(ii). num == 0 的情况下 p0 指针写入并移动

class Solution {
 public:
  void sortColors(vector<int>& nums) {
    int p0 = 0; // p0 指向 0 的写入点
    int p1 = 0; // p1 指向 1 的写入点
    for (int i = 0; i < nums.size(); ++i) {
      int num = nums[i];
      // 不管是几先暂定写入2
      nums[i] = 2;
      // 如果 num < 2的情况下, p1指针写入并移动
      if (num < 2) {
        nums[p1++] = 1;
      }
      // 如果 num == 0 的情况下, p0指针写入并移动
      if (num == 0) {
        nums[p0++] = 0;
      }
    }
    return;
  }
};

Reference

[1]. 被忽视的 partition 算法. https://selfboot.cn/2016/09/01/lost_partition/.
[2]. 生成数组区间范围内随机数. https://blog.csdn.net/hellokandy/article/details/90045187.
[3]. 算法 第四版.


转载请注明来源, from goldandrabbit.github.io

💰

×

Help us with donation