52ky 发表于 2022-5-5 09:21:19

单链表 C++ 的快速选择算法

问题
我需要一种算法,它可以找到具有线性时间复杂度 O(n) 和恒定空间复杂度 O(1) 的单链表的中位数。

编辑:单链表是 C 风格的单链表。不允许使用 stl(没有容器,没有函数,禁止所有 stl,例如没有 std::forward_list)。不允许在任何其他容器(例如数组)中移动数字。

O(logn) 的空间复杂度是可以接受的,因为我的列表实际上小于 100。另外,我不允许使用像 nth_element 这样的 STL 函数

基本上,我有一个包含 3*10^6 元素的链表,我需要在 3 秒内得到中位数,所以我不能用排序算法对列表进行排序(它会是 O(nlogn),它可能需要 10-14 秒)。

我在网上做了一些搜索,发现 std::vector 与 quickselect 的中位数可以在 O(n) 和 O(1) 空间中找到(最坏的情况是 O(n^2),但很少看到),例如:

但是我找不到任何算法来为链表执行此操作。问题是,我可以使用数组索引随机访问向量,如果我想修改算法,复杂度会大得多,因为。例如,当我将枢轴索引更改为左时,我实际上需要遍历列表以获取新元素并走得更远(这将使我的列表至少 O(k n) 和一个大 k,甚至 O(n^ 2)…)。

编辑2:

我知道我有太多变量,但我一直在测试不同的东西,我仍然在我的代码中。 . .

我当前的代码:
#include <bits/stdc++.h>

using namespace std;

template <class T> class Node {
    public:
    T data;
    Node<T> *next;
};

template <class T> class List {
    public:
    Node<T> *first;
};

template <class T> T getMedianValue(List<T> & l) {
    Node<T> *crt,*pivot,*incpivot;
    int left, right, lung, idx, lungrel,lungrel2, left2, right2, aux, offset;
    pivot = l.first;
    crt = pivot->next;
    lung = 1;
//lung is the lenght of the linked list (yeah it's lenght in romanian...)
//lungrel and lungrel2 are the relative lenghts of the part of
//the list I am processing, e.g: 2 3 4 in a list with 1 2 3 4 5
    right = left = 0;
    while (crt != NULL) {
      if(crt->data < pivot->data){
            aux = pivot->data;
            pivot->data = crt->data;
            crt->data = pivot->next->data;
            pivot->next->data = aux;
            pivot = pivot->next;
            left++;
      }
      else right++;
       // cout<<crt->data<<endl;
      crt = crt->next;
      lung++;
    }
    if(right > left) offset = left;
//cout<<endl;
//cout<<pivot->data<<" "<<left<<" "<<right<<endl;
//printList(l);
//cout<<endl;
    lungrel = lung;
    incpivot = l.first;
   // offset = 0;
    while(left != right){
      //cout<<"parcurgere"<<endl;
      if(left > right){
            //cout<<endl;
            //printList(l);
            //cout<<endl;
            //cout<<"testleft "<<incpivot->data<<" "<<left<<" "<<right<<endl;
            crt = incpivot->next;
            pivot = incpivot;
            idx = offset;left2 = right2 = lungrel = 0;
            //cout<<idx<<endl;
            while(idx < left && crt!=NULL){
               if(pivot->data > crt->data){
                   //cout<<"1crt "<<crt->data<<endl;
                     aux = pivot->data;
                     pivot->data = crt->data;
                     crt->data = pivot->next->data;
                     pivot->next->data = aux;
                     pivot = pivot->next;
                     left2++;lungrel++;
                  }
                  else {
                      right2++;lungrel++;
                      //cout<<crt->data<<" "<<right2<<endl;
                  }
                  //cout<<crt->data<<endl;
                  crt = crt->next;
                  idx++;
             }
             left = left2 + offset;
             right = lung - left - 1;
             if(right > left) offset = left;
             //if(pivot->data == 18) return 18;
             //cout<<endl;
             //cout<<"l "<<pivot->data<<" "<<left<<" "<<right<<" "<<right2<<endl;
         //printList(l);
      }
      else if(left < right && pivot->next!=NULL){
            idx = left;left2 = right2 = 0;
            incpivot = pivot->next;offset++;left++;
            //cout<<endl;
            //printList(l);
            //cout<<endl;
            //cout<<"testright "<<incpivot->data<<" "<<left<<" "<<right<<endl;
            pivot = pivot->next;
            crt = pivot->next;
            lungrel2 = lungrel;
            lungrel = 0;
         // cout<<"p right"<<pivot->data<<" "<<left<<" "<<right<<endl;
            while((idx < lungrel2 + offset - 1) && crt!=NULL){
               if(crt->data < pivot->data){
                //   cout<<"crt "<<crt->data<<endl;
                     aux = pivot->data;
                     pivot->data = crt->data;
                     crt->data = (pivot->next)->data;
                     (pivot->next)->data = aux;
                     pivot = pivot->next;
               //    cout<<"crt2 "<<crt->data<<endl;
                     left2++;lungrel++;
                  }
                  else right2++;lungrel++;
                  //cout<<crt->data<<endl;
                  crt = crt->next;
                  idx++;
             }
             left = left2 + left;
             right = lung - left - 1;
               if(right > left) offset = left;
            // cout<<"r "<<pivot->data<<" "<<left<<" "<<right<<endl;
         //printList(l);
      }
      else{
            //cout<<cmx<<endl;
            return pivot->data;
      }
    }
    //cout<<cmx<<endl;
    return pivot->data;
}
template <class T> void printList(List<T> const & l) {
    Node<T> *tmp;
    if(l.first != NULL){
      tmp = l.first;
      while(tmp != NULL){
            cout<<tmp->data<<" ";
            tmp = tmp->next;
      }
    }
}
template <class T> void push_front(List<T> & l, int x)
{
    Node<T>* tmp = new Node<T>;

    tmp->data = x;

    tmp->next = l.first;
    l.first = tmp;
}

int main(){
    List<int> l;
    int n = 0;
    push_front(l, 19);
    push_front(l, 12);
    push_front(l, 11);
    push_front(l, 101);
    push_front(l, 91);
    push_front(l, 21);
    push_front(l, 9);
    push_front(l, 6);
    push_front(l, 25);
    push_front(l, 4);
    push_front(l, 18);
    push_front(l, 2);
    push_front(l, 8);
    push_front(l, 10);
    push_front(l, 200);
    push_front(l, 225);
    push_front(l, 170);
    printList(l);
    n=getMedianValue(l);
    cout<<endl;
    cout<<n;

    return 0;
}
您对如何使快速选择适应单独列出的链接或其他可以解决我的问题的算法有任何建议吗?

回答
在您的问题中,您提到将枢轴点向左移动时遇到问题,因为这需要遍历列表。如果你做对了,你只需要遍历整个列表两次:

如果你不太在意选择一个好的枢轴,只是喜欢选择列表中的第一个元素作为枢轴(如果数据已经排序,这将导致最坏情况 O(n^ 2)时间复杂度),不需要第一步。

如果第一次遍历链表的尾端,通过保持指向链表尾的指针,就不必再次遍历它来找到尾了。此外,如果您使用标准 Lomuto 分区方案(我不这样做,原因如下所述),您还必须在列表中保留两个指针,表示标准 Lomuto 分区方案的 i 和 j 索引.通过使用这些指针,您永远不必遍历列表来访问单个元素。

此外,如果您保留一个指向每个分区的中间和结尾的指针,那么当您稍后必须对其中一个分区进行排序时,您不必再次遍历该分区来找到中间和结尾。

我现在已经为链表创建了我自己的 QuickSelect 算法实现,我已经在下面发布了。

既然你说链表是单链表,不能升级成双链表,那我就不能用Hoare分区方案了,因为反向迭代单链表代价很大。所以我正在使用通常效率较低的 Lomuto 分区方案。

在使用 Lomuto 分区方案时,通常选择第一个元素或最后一个元素作为轴。然而,选择这两种方法中的任何一种都有一个缺点,即排序数据将导致算法的最坏情况时间复杂度为 O(n^2)。这可以通过根据“三的中值”选择轴来防止。规则,即从第一个元素、中间元素和最后一个元素的中值中选择轴。所以在我的实现中,我使用这个“三位数”。规则。

此外,Lomuto 分区方案通常会创建两个分区,一个用于小于轴的值,一个用于大于或等于轴的值。但是,如果所有值都相同,这将导致 O(n^2) 的最坏情况时间复杂度。所以,在我的实现中,我创建了三个分区,一个用于小于轴的值,一个用于大于轴的值,一个用于等于轴的值。

虽然这些措施并没有完全消除 O(n^2) 的最坏情况时间复杂度的可能性,但它们至少使其极不可能发生。

我遇到的一个重要问题是,对于偶数个元素,中位数被定义为两个“中间”元素的算术平均值。或“中间”元素。所以我不能简单地调用函数 find_nth_element 因为例如,如果元素的总数是 14,那么我将寻找第 7 和第 8 个最大的元素。这意味着我将不得不两次调用这样的函数,这将是低效的。所以我写了一个搜索“中位数”的函数。同时元素。尽管这使代码更加复杂,但与不必调用相同函数两次的优势相比,由于额外代码复杂性导致的性能损失应该是最小的。

请注意,虽然我的实现完全在 C++ 编译器上编译,但我不会将其称为教科书 C++ 代码,因为它禁止使用 C++ 标准模板库中的任何内容。所以我的代码是 C 代码和 C++ 代码的混合体。
#include <iostream>
#include <iomanip>
#include <cassert>

//remove the comment in the following line for extra debugging information
//#define PRINT_DEBUG_INFO

template <typename T>
struct Node
{
    T data;
    Node<T> *next;
};

// NOTE:
// The return type is not necessarily the same as the data type. The reason for this is
// that, for example, the data type "int" requires a "double" as a return type, so that
// the arithmetic mean of "3" and "6" returns "4.5".
// This function may require template specializations to handle overflow or wrapping.
template<typename T, typename U>
U arithmetic_mean( const T &first, const T &second )
{
    return ( static_cast<U>(first) + static_cast<U>(second) ) / 2;
}

//the main loop of the function find_median can be in one of the following three states
enum LoopState
{
    //we are looking for one median value
    LOOPSTATE_LOOKINGFORONE,

    //we are looking for two median values, and the returned median
    //will be the arithmetic mean of the two
    LOOPSTATE_LOOKINGFORTWO,

    //one of the median values has been found, but we are still searching for
    //the second one
    LOOPSTATE_FOUNDONE
};

template <
    typename T, //type of the data
    typename U//type of the return value
>
U find_median( Node<T> *list )
{
    //This variable points to the pointer to the first element of the current partition.
    //During the partition phase, the linked list will be broken and reassembled afterwards, so
    //the pointer this pointer points to will be nullptr until it is reassembled.
    Node<T> **pp_start = &list;

    //these pointer are maintained for accessing the middle of the list for selecting a pivot using
    //the "median-of-three" rule
    Node<T> *p_middle;
    Node<T> *p_end;

    //result is not defined if list is empty
    assert( *pp_start != nullptr );

    //in the main loop, this variable always holds the number of elements in the current partition
    int num_total = 1;

    // First, we must traverse the entire linked list in order to determine the number of elements,
    // in order to calculate k1 and k2. If it is odd, then the median is defined as the k'th smallest
    // element where k = n / 2. If the number of elements is even, then the median is defined as the
    // arithmetic mean of the k'th element and the (k+1)'th element.
    // We also set a pointer to the nodes in the middle and at the end, which will be required later
    // for selecting a pivot according to the "median-of-three" rule.
    p_middle = *pp_start;
    for ( p_end = *pp_start; p_end->next != nullptr; p_end = p_end->next )
    {
      num_total++;
      if ( num_total % 2 == 0 ) p_middle = p_middle->next;
    }   

    // find out whether we are looking for only one or two median values
    enum LoopState loop_state = num_total % 2 == 0 ? LOOPSTATE_LOOKINGFORTWO : LOOPSTATE_LOOKINGFORONE;

    //set k to the index of the middle element, or if there are two middle elements, to the left one
    int k = ( num_total - 1 ) / 2;

    // If we are looking for two median values, but we have only found one, then this variable will
    // hold the value of the one we found. Whether we have found one can be determined by the state of
    // the variable loop_state.
    T val_found;

    for (;;)
    {
      assert( *pp_start != nullptr );
      assert( p_middle!= nullptr );
      assert( p_end   != nullptr );
      assert( num_total != 0 );

      if ( num_total == 1 )
      {
            switch ( loop_state )
            {
            case LOOPSTATE_LOOKINGFORONE:
                return (*pp_start)->data;
            case LOOPSTATE_FOUNDONE:
                return arithmetic_mean<T,U>( val_found, (*pp_start)->data );
            default:
                assert( false ); //this should be unreachable
            }
      }

      //select the pivot according to the "median-of-three" rule
      T pivot;
      if ( (*pp_start)->data < p_middle->data )
      {
            if ( p_middle->data < p_end->data )
                pivot = p_middle->data;
            else if ( (*pp_start)->data < p_end->data )
                pivot = p_end->data;
            else
                pivot = (*pp_start)->data;
      }
      else
      {
            if ( (*pp_start)->data < p_end->data )
                pivot = (*pp_start)->data;
            else if ( p_middle->data < p_end->data )
                pivot = p_end->data;
            else
                pivot = p_middle->data;
      }


      // We will be dividing the current partition into 3 new partitions (less-than,
      // equal-to and greater-than) each represented as a linked list. Each list
      // requires a pointer to the start of the list and a pointer to the pointer at
      // the end of the list to write the address of new elements to. Also, when
      // traversing the lists, we need to keep a pointer to the middle of the list,
      // as this information will be required for selecting a new pivot in the next
      // iteration of the loop. The latter is not required for the equal-to partition,
      // as it would never be used.
      Node<T> *p_less    = nullptr, **pp_less_end    = &p_less,    **pp_less_middle    = &p_less;
      Node<T> *p_equal   = nullptr, **pp_equal_end   = &p_equal;
      Node<T> *p_greater = nullptr, **pp_greater_end = &p_greater, **pp_greater_middle = &p_greater;

      // These pointers are only used as a cache to the location of end node. Despite
      // their similar name, their function is very different to pp_less_end and
      // pp_greater_end.
      Node<T> *p_less_end    = nullptr;
      Node<T> *p_greater_end = nullptr;

      // counter for the number of elements in each partition
      int num_less = 0;
      int num_equal = 0;
      int num_greater = 0;

      // NOTE:
      // The following loop will temporarily split the linked list. It will be merged later.
      Node<T> *p_next_node = *pp_start;
      *pp_start = nullptr;
      for ( int a = 0; a < num_total; a++ )
      {
            assert( p_next_node != nullptr );

            Node<T> *p_current_node = p_next_node;
            p_next_node = p_next_node->next;

            if ( p_current_node->data < pivot )
            {
                //link node to pp_less
                assert( *pp_less_end == nullptr );
                *pp_less_end = p_current_node;
                pp_less_end = &p_current_node->next;
                p_current_node->next = nullptr;

                num_less++;
                if ( num_less % 2 == 0 )
                {
                  pp_less_middle = &(*pp_less_middle)->next;
                }

                //setting this variable is only done for caching purposes and is not
                //directly related to the logic of the other variable pp_less_end
                p_less_end = p_current_node;
            }
            else if ( p_current_node->data == pivot )
            {
                //link node to pp_equal
                assert( *pp_equal_end == nullptr );
                *pp_equal_end = p_current_node;
                pp_equal_end = &p_current_node->next;
                p_current_node->next = nullptr;

                num_equal++;
            }
            else
            {
                //link node to pp_greater
                assert( *pp_greater_end == nullptr );
                *pp_greater_end = p_current_node;
                pp_greater_end = &p_current_node->next;
                p_current_node->next = nullptr;

                num_greater++;
                if ( num_greater % 2 == 0 )
                {
                  pp_greater_middle = &(*pp_greater_middle)->next;
                }

                //setting this variable is only done for caching purposes and is not
                //directly related to the logic of the other variable pp_greater_end
                p_greater_end = p_current_node;
            }
      }

      assert( num_total == num_less + num_equal + num_greater );

#ifdef PRINT_DEBUG_INFO
      //when PRINT_DEBUG_INFO is defined, it will print the length of all partitions and their contents
      {
            Node<T> *p;
            std::cout << std::setfill( '0' );
            std::cout << "partition lengths: ";
            std::cout <<
                std::setw( 2 ) << num_less    << " " <<
                std::setw( 2 ) << num_equal   << " " <<
                std::setw( 2 ) << num_greater << " " <<
                std::setw( 2 ) << num_total   << "\n";
            std::cout << "less: ";
            for ( p = p_less; p != nullptr; p = p->next ) std::cout << p->data << " ";
            std::cout << "\nequal: ";
            for ( p = p_equal; p != nullptr; p = p->next ) std::cout << p->data << " ";
            std::cout << "\ngreater: ";
            for ( p = p_greater; p != nullptr; p = p->next ) std::cout << p->data << " ";
            std::cout << "\n\n" << std::flush;
      }
#endif

      //insert less-than partition into list
      assert( *pp_start == nullptr );
      *pp_start = p_less;

      //insert equal-to partition into list
      assert( *pp_less_end == nullptr );
      *pp_less_end = p_equal;

      //insert greater-than partition into list
      assert( *pp_equal_end == nullptr );
      *pp_equal_end = p_greater;

      //link list to previously cut off part
      assert( *pp_greater_end == nullptr );
      *pp_greater_end = p_next_node;

      //if less-than partition is large enough to hold both possible median values
      if ( k + 2 <= num_less )
      {
            //set the next iteration of the loop to process the less-than partition
            //pp_start is already set to the desired value
            p_middle = *pp_less_middle;
            p_end = p_less_end;
            num_total = num_less;
      }

      //else if less-than partition holds one of both possible median values
      else if ( k + 1 == num_less )
      {
            if ( loop_state == LOOPSTATE_LOOKINGFORTWO )
            {
                //the equal_to partition never needs sorting, because all members are already equal
                val_found = p_equal->data;
                loop_state = LOOPSTATE_FOUNDONE;
            }
            //set the next iteration of the loop to process the less-than partition
            //pp_start is already set to the desired value
            p_middle = *pp_less_middle;
            p_end = p_less_end;
            num_total = num_less;
      }

      //else if equal-to partition holds both possible median values
      else if ( k + 2 <= num_less + num_equal )
      {
            //the equal_to partition never needs sorting, because all members are already equal
            if ( loop_state == LOOPSTATE_FOUNDONE )
                return arithmetic_mean<T,U>( val_found, p_equal->data );
            return p_equal->data;
      }

      //else if equal-to partition holds one of both possible median values
      else if ( k + 1 == num_less + num_equal )
      {
            switch ( loop_state )
            {
            case LOOPSTATE_LOOKINGFORONE:
                return p_equal->data;
            case LOOPSTATE_LOOKINGFORTWO:
                val_found = p_equal->data;
                loop_state = LOOPSTATE_FOUNDONE;
                k = 0;
                //set the next iteration of the loop to process the greater-than partition
                pp_start = pp_equal_end;
                p_middle = *pp_greater_middle;
                p_end = p_greater_end;
                num_total = num_greater;
                break;
            case LOOPSTATE_FOUNDONE:
                return arithmetic_mean<T,U>( val_found, p_equal->data );
            }
      }

      //else both possible median values must be in the greater-than partition
      else
      {
            k = k - num_less - num_equal;

            //set the next iteration of the loop to process the greater-than partition
            pp_start = pp_equal_end;
            p_middle = *pp_greater_middle;
            p_end = p_greater_end;
            num_total = num_greater;
      }
    }
}


// NOTE:
// The following code is not part of the algorithm, but is only intended to test the algorithm

template <typename T>
class List
{
public:
    List() : first( nullptr ) {}

    // the following is required to abide by the rule of three/five/zero
    // see: https://en.cppreference.com/w/cpp/language/rule_of_three
    List( const List<T> & ) = delete;
    List( const List<T> && ) = delete;
    List<T>& operator=( List<T> & ) = delete;
    List<T>& operator=( List<T> && ) = delete;

    ~List()
    {
      Node<T> *p = first;

      while ( p != nullptr )
      {
            Node<T> *temp = p;
            p = p->next;
            delete temp;
      }
    }

    void push_front( int data )
    {
      Node<T>* tmp = new Node<T>;

      tmp->data = data;

      tmp->next = first;
      first = tmp;
    }

    //member variables
    Node<T> *first;
};

int main()
{
    List<int> l;

    int unsorted_data[] = { 6, 19, 3, 7, 4, 9, 8, 2, 18, 4, 10, 11, 12, 10 };

    //create singly-linked list
    for ( int i = 0; i < sizeof( unsorted_data ) / sizeof( *unsorted_data ); i++ ) l.push_front( unsorted_data );

    std::cout << "The median is: " << std::setprecision(10) << find_median<int,double>( l.first ) << std::endl;

    return 0;
}
我已经用一百万个随机生成的元素成功地测试了我的代码,它几乎立即找到了正确的中位数。



页: [1]
查看完整版本: 单链表 C++ 的快速选择算法