Mailund on the Internet

On Writing, Science, Programming and more

Skew algorithm in Python and Go

I’m currently supervising some BSc projects where they implement various string algorithms. Yesterday, I got an email where a student asked for a longer example of the Skew/DC3 algorithm. Naturally, I felt too lazy to make one, but this morning I figured that it wouldn’t be too hard to implement it in Python, and then he could explore it himself. Then, between two Zoom meetings, I translated it into Go as well, since that is what this group is using for their project.

It seems that most of my students these days are not as well versed in C as I had hoped, so maybe it was a mistake to use C as the language in my book. I feel like I might have to implement more of the algorithms in Python for my class in the future. If I am not too lazy.

The skew algorithm is for computing the suffix array of a string. If you have a string x, then you can imagine extracting all its suffixes x[i:] and sorting them. One way of representing these sorted suffixes is an array of indices, SA, such that x[SA[i]:] gives you the suffix that would be at index i in the sorted suffixes, or in other words, it is an array such that

for i in SA:
    print(x[i:])

would print all the suffixes in lexicographical order. This array is the suffix array, and it is used in a wide variety of string algorithms.

If you actually generate all the suffixes, you are spending \( O(n^2) \) just on that (the sum of the lengths of the suffixes is quadratic). If you then sort them on top of that, you can do it in \( O(n^2) \) with a radix sort or \( O(n^2\log n) \) with a comparison-based sort (because comparing two strings is in \( O(n) \)). So a straightforward approach is not a good idea, especially in bioinformatics where \(n\) can easily be a billion.
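
For reference, the straightforward approach is a one-liner in Python. This is only a sketch for testing (naive_sa() is a placeholder name, not part of the code later in the post): it materializes every suffix, so it uses quadratic space before the comparison sort even starts, but it is handy for checking the output of a real implementation on small strings.

```python
def naive_sa(x: str) -> list[int]:
    "Quadratic reference suffix array: sort the indices by their suffixes."
    # sorted() computes the key x[i:] once per index and keeps it,
    # so all suffixes exist at the same time: O(n^2) space.
    return sorted(range(len(x)), key=lambda i: x[i:])


x = "mississippi"
SA = naive_sa(x)
print(SA)  # [10, 7, 4, 1, 0, 9, 8, 6, 3, 5, 2]
for i in SA:
    print(x[i:])  # one suffix per line, from "i" to "ssissippi"
```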

There are, however, several algorithms for building the suffix array in linear time, and the skew algorithm is one of them. It isn’t the fastest around (that is probably the SA-IS algorithm, but I will leave that for another day), but it is relatively simple.

It is a divide-and-conquer algorithm where you split the input into two sets of indices: the indices that are zero modulo 3, i % 3 == 0, which we call SA3 (although they aren’t a suffix array yet), and the indices where i % 3 == 1 or i % 3 == 2, which we call SA12. Then, recursively, you sort the indices in SA12, so you have the suffixes that start at those indices in the right order. That is a problem of size \( 2n/3 \), and if it is the only recursion we do (and it is), then the recurrence equation for the running time is \( T(n) = T(2n/3) + W(n) \), where \( W(n) \) is the work we do directly in the algorithm. If \(W(n)=O(n)\), you do linear work per call, and then the recurrence equation is \(T(n) = T(2n/3)+O(n) = O(n)\), because unrolling the recurrence gives a geometric sum.

That is generally the case: if you have a recursive algorithm where you don’t use more than linear time per call, and you recurse on something that is a fixed fraction of the original size, you get something in \(O(n)\), and heaps of divide and conquer algorithms are based on this. Usually, you see \(T(n) = T(n/2)+O(n)\) where the recursive case is half the data, but it works for any fraction. For the skew algorithm, we need 2/3, and I’ll explain why a little later.
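
Spelling out the geometric sum for the skew recurrence, assuming the direct work is at most \(cn\) for some constant \(c\):

```latex
\begin{aligned}
T(n) &\le cn + T\left(\tfrac{2}{3}n\right)
      \le cn + \tfrac{2}{3}cn + T\left(\left(\tfrac{2}{3}\right)^{2} n\right)
      \le \cdots \\
     &\le cn \sum_{k=0}^{\infty} \left(\tfrac{2}{3}\right)^{k}
      = \frac{cn}{1 - 2/3} = 3cn = O(n).
\end{aligned}
```

The same calculation for \(T(n) = T(n/2)+O(n)\) gives \(2cn\). The separator and the rounding make the recursive problem slightly longer than \(2n/3\), but for large enough \(n\) it is still below a fixed fraction of \(n\) (say \(3n/4\)), so the bound stays linear.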

The algorithm computes a suffix array from a string, so we can’t recurse on a list of indices. Instead, we build another string from SA12, such that we can reconstruct SA12 from the new string’s suffix array. This bit is a little involved, and is the crux of the algorithm. It involves building a new alphabet from sorted triplets of letters. I’ll explain it shortly.

Anyway, recursively you sort SA12. Then we sort SA3 using SA12. If you want to know the order of two indices i and j in SA3, you can compare their first letters, x[i] and x[j], and if they are different, you know which one goes first. If they are the same, then you need to compare x[i:] < x[j:], which is x[i]x[i+1:] < x[j]x[j+1:], but since x[i] = x[j], that is the same as x[i+1:] < x[j+1:]: if the first characters are the same, then the order is determined by the rest of the suffixes. But if i % 3 == 0 and j % 3 == 0, then (i+1) % 3 == 1 and (j+1) % 3 == 1, so we already have those sorted. We can look them up in SA12.

Even better, though, we can think about sorting SA3 as a radix sort. There, you start from the back of the strings and keep bucket sorting until you get to the beginning. In the second-to-last sort, you have your suffixes sorted according to x[i+1:], i.e., according to everything except the first letter. Doing the full radix sort would take \(O(n^2)\) time, but once SA12 is sorted, we can extract the second-to-last array directly from it: take the indices i in SA12 with i % 3 == 1, in the order they appear in SA12, and subtract one from each. If you start from there and do a single bucket sort on the first letter, you get SA3. This only takes linear time for the bucket sort.
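
A small sketch of just this step. To look at SA3 in isolation, it gets a correctly sorted SA12 from a quadratic naive sort, and Python’s sorted(), which is stable, stands in for the linear-time bucket sort. The function names are placeholders; the special case for a final class-0 index mirrors what skew_rec() does below.

```python
def naive_sa(x: str) -> list[int]:
    "Quadratic reference suffix array, only used to get a sorted SA12."
    return sorted(range(len(x)), key=lambda i: x[i:])


def sa3_from_sa12(x: str, SA12: list[int]) -> list[int]:
    "Sort the class-0 indices using the sorted class-1 indices in SA12."
    # i - 1 for each class-1 index i, in SA12 order, is already sorted by
    # everything except its first letter. A final class-0 index has an
    # empty rest, so it goes first.
    SA3 = ([len(x) - 1] if len(x) % 3 == 1 else []) + \
          [i - 1 for i in SA12 if i % 3 == 1]
    # One stable sort on the first letter finishes the job.
    return sorted(SA3, key=lambda i: x[i])


x = "mississippi"
SA12 = [i for i in naive_sa(x) if i % 3 != 0]
print(SA12)                    # [10, 7, 4, 1, 8, 5, 2]
print(sa3_from_sa12(x, SA12))  # [0, 9, 6, 3]
```

The result is exactly the order of the class-0 indices in the full suffix array of mississippi.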

Then, with SA12 and SA3 sorted, you merge them, and you are done.

A Python implementation (where I haven’t yet explained the recursive case or the helper functions) looks like this:

def skew_rec(x : list[int], asize : int) -> list[int]:
    "Recursive skew SA construction algorithm."

    SA12 = [i for i in range(len(x)) if i % 3 != 0]
    
    SA12 = radix3(x, asize, SA12)
    new_alpha = collect_alphabet(x, SA12)
    if len(new_alpha) < len(SA12):
        # Recursively sort SA12.
        # Construct the u string and compute its suffix array
        u = build_u(x, new_alpha)
        # For the recursion, remember that the real alphabet has
        # two sentinels, so + 2
        sa_u = skew_rec(u, len(new_alpha) + 2)
        # Then map u's suffix array back to a sorted SA12
        m = len(sa_u) // 2
        SA12 = [u_idx(i, m) for i in sa_u if i != m]

    # Special case if the last index is class 0. Then the
    # following class 1 isn't there, but we should treat it
    # as the smallest string in the class.
    SA3 = ([len(x) - 1] if len(x) % 3 == 1 else []) + \
          [i - 1 for i in SA12 if i % 3 == 1]
    SA3 = bucket_sort(x, asize, SA3)
    return merge(x, SA12, SA3)

You construct SA12, do some recursive magic, then construct SA3 from the sorted SA12, and then you merge the two.

In the Python code my string is a list of integers. That’s because the recursive cases get new alphabets where I use integers, so I need that here. The asize argument is the size of the alphabet the string is over. If you want a wrapper that works with a string, it looks like this:

def skew(x : str) -> list[int]:
    "Skew algorithm for a string."
    # The skew_rec() function wants a list of integers,
    # so we convert the string in the first call.
    # It is only because of the safe_idx() hack that we
    # need to convert the string; without it, we could work
    # with both str and list[int], but the sentinel we generate
    # is int, and we have to compare letters, so all letters must
    # then be int.
    # I am assuming that the alphabet size is 256 here, although
    # of course it might not be. It is a simplification instead of
    # remapping the string.
    return skew_rec([ord(y) for y in x], 256)
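
The remapping the comment mentions could look like the sketch below (remap() is a hypothetical helper, not part of the code above). It gives each distinct letter a number from 1 to k in sorted order, so code 0 stays free for the padding that safe_idx() returns, and the alphabet size becomes k + 1 instead of a hardcoded 256.

```python
def remap(x: str) -> tuple[list[int], int]:
    """Map the letters of x to 1..k, preserving their order.

    Returns the integer string and the alphabet size k + 1; code 0 is
    left free for the sentinel that safe_idx() produces.
    """
    letters = sorted(set(x))
    code = {c: rank + 1 for rank, c in enumerate(letters)}
    return [code[c] for c in x], len(letters) + 1


print(remap("mississippi"))  # ([2, 1, 4, 4, 1, 4, 4, 1, 3, 3, 1], 5)
```

With it, the wrapper body would become return skew_rec(*remap(x)). Since sorted() orders the characters by code point, just like ord(), the suffix order doesn’t change, but the count arrays shrink to the letters that actually occur, and characters above 255 (or a zero byte in the input) no longer break anything.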

In Go, the corresponding code looks like this:

func getSA12(x []int) []int {
	SA12 := []int{}
	for i := 0; i < len(x); i++ {
		if i%3 != 0 {
			SA12 = append(SA12, i)
		}
	}
	return SA12
}

func getSA3(x []int, SA12 []int) []int {
	SA3 := []int{}
	// Special case: if the last index is class 0, the following
	// class 1 index isn't there, so it is the smallest and goes first.
	if len(x)%3 == 1 {
		SA3 = append(SA3, len(x)-1)
	}
	for _, i := range SA12 {
		if i%3 == 1 {
			SA3 = append(SA3, i-1)
		}
	}
	return SA3
}

func skew(x []int, asize int) []int {
	SA12 := radix3(x, asize, getSA12(x))
	alpha := collectAlphabet(x, SA12)
	if len(alpha) < len(SA12) {
		// Build u and its SA.
		u := buildU(x, alpha)
		usa := skew(u, len(alpha)+2) // +2 for sentinels
		// Then map back to SA12 indices
		m := len(usa) / 2
		SA12 = []int{}
		for _, i := range usa {
			if i != m {
				SA12 = append(SA12, uidx(i, m))
			}
		}
	}
	SA3 := bucketSort(x, asize, getSA3(x, SA12), 0)
	return merge(x, SA12, SA3)
}

func str2int(x string) []int {
	i := []int{}
	for _, c := range x {
		i = append(i, int(c))
	}
	return i
}

// Skew is the skew algorithm for a string.
// The skew() function wants a slice of integers,
// so we convert the string in the first call.
// I am assuming that the alphabet size is 256 here, although
// of course it might not be. It is a simplification instead of
// remapping the string.
func Skew(x string) []int {
	return skew(str2int(x), 256)
}

Let’s go back to the handling of SA12. Constructing the first list of indices with the list comprehension is clear enough: we get the indices that are one or two modulo three. The next step is radix3(). What is that doing? It sorts the suffixes we give it (the indices in SA12) according to their first three letters. It is a radix sort that only looks at a prefix of length 3, so it runs in linear time.

def safe_idx(x : list[int], i : int) -> int:
    "Hack to get zero if we index beyond the end."
    return 0 if i >= len(x) else x[i]

def symbcount(x : list[int], asize : int) -> list[int]:
    "Count how often we see each character in the alphabet."
    # This is what collections.Counter does, but we need the
    # alphabet to be sorted integers, so we do it manually.
    counts = [0] * asize
    for c in x:
        counts[c] += 1
    return counts

def cumsum(counts : list[int]) -> list[int]:
    "Compute the cumulative sum from the character count."
    res, acc = [0] * len(counts), 0
    for i, k in enumerate(counts):
        res[i] = acc
        acc += k
    return res

def bucket_sort(x : list[int], asize : int,
                idx : list[int], offset : int = 0) -> list[int]:
    "Sort indices in idx according to x[i + offset]."
    sort_symbs = [safe_idx(x, i + offset) for i in idx]
    counts = symbcount(sort_symbs, asize)
    buckets = cumsum(counts)
    out = [None] * len(idx)
    for i in idx:
        bucket = safe_idx(x, i + offset)
        out[buckets[bucket]] = i
        buckets[bucket] += 1
    return out

def radix3(x : list[int], asize : int, idx : list[int]) -> list[int]:
    "Sort indices in idx according to their first three letters in x."
    idx = bucket_sort(x, asize, idx, 2)
    idx = bucket_sort(x, asize, idx, 1)
    return bucket_sort(x, asize, idx)

The safe_idx() function uses a trick for correctly sorting triplets when you might index beyond the end of the string. You want a string that is a prefix of a longer string to come first (that is how lexicographical order is defined), and if you reserve zero as a special character that can only appear at indices past the end of the string, you get that without any other special cases. The zero byte is practically never used for anything except as a termination sentinel, because that is how it is used in C, so we can reserve it for this purpose. I am also going to steal the number 1 as a special sentinel, and make sure that my alphabet starts at 2 when I construct it. We will see why a little later, but it is the reason for the +2 in the alphabet size in the recursive calls in skew_rec() and skew().

The code is a straightforward bucket sort, and since it is stable, applying it three times, from the third letter back to the first, gives us a radix sort on prefixes of length 3.

The Go implementation looks like this:

func safe_idx(x []int, i int) int {
	if i >= len(x) {
		return 0
	} else {
		return x[i]
	}
}

func symbcount(x []int, idx []int, offset int, asize int) []int {
	counts := make([]int, asize)
	for _, i := range idx {
		counts[safe_idx(x, i+offset)]++
	}
	return counts
}

func cumsum(counts []int) []int {
	res := make([]int, len(counts))
	acc := 0
	for i, k := range counts {
		res[i] = acc
		acc += k
	}
	return res
}

func bucketSort(x []int, asize int, idx []int, offset int) []int {
	counts := symbcount(x, idx, offset, asize)
	buckets := cumsum(counts)
	out := make([]int, len(idx))
	for _, i := range idx {
		bucket := safe_idx(x, i+offset)
		out[buckets[bucket]] = i
		buckets[bucket]++
	}
	return out
}

func radix3(x []int, asize int, idx []int) []int {
	idx = bucketSort(x, asize, idx, 2)
	idx = bucketSort(x, asize, idx, 1)
	return bucketSort(x, asize, idx, 0)
}

That gives us something closer to a sorted array, but since we don’t look past the first three letters, it won’t have handled suffixes that share prefixes of length three or more. Those could be in the wrong order.
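
To see the problem concretely, here is a compact restatement of the three passes, run on mississippi. It is a sketch, not the post’s radix3(): Python’s stable sorted() stands in for the bucket sort, and the hypothetical letter() plays the role of safe_idx().

```python
def letter(x: list[int], i: int) -> int:
    "x[i], or the sentinel 0 if i is past the end."
    return x[i] if i < len(x) else 0


def lsd_radix3(x: list[int], idx: list[int]) -> list[int]:
    "Sort idx by the first three letters, least significant letter first."
    for offset in (2, 1, 0):
        # sorted() is stable, which is what makes the three passes add up.
        idx = sorted(idx, key=lambda i: letter(x, i + offset))
    return idx


x = [ord(c) for c in "mississippi"]
SA12 = [i for i in range(len(x)) if i % 3 != 0]
print(lsd_radix3(x, SA12))  # [10, 7, 1, 4, 8, 2, 5]
```

Suffixes 1 and 4 both start with (iss), and 2 and 5 both start with (ssi), so they stay in their input order. In the real suffix order, x[4:] = issippi comes before x[1:] = ississippi, and x[5:] = ssippi comes before x[2:] = ssissippi, so those two pairs are exactly what the recursion has to fix.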

We now take those radix3()-sorted suffixes, and we build a new alphabet from them. We take all the triplets (the first three characters of these suffixes) and give each of them a number. That number will be the rank of the triplet when the triplets are sorted. So, run through the sorted indices, pull out the corresponding triplets, and each time you see a new triplet, give it the next number. Since the indices are sorted, we see the triplets in increasing order, so the map we construct assigns numbers to triplets in an order-preserving way.

In Python:

TRIPLET = tuple[int,int,int]
TRIPLET_DICT = dict[TRIPLET,int]

def triplet(x : list[int], i : int) -> TRIPLET:
    "Extract the triplet (x[i],x[i+1],x[i+2])."
    return (safe_idx(x, i), safe_idx(x, i + 1), safe_idx(x, i + 2))

def collect_alphabet(x : list[int], idx : list[int]) -> TRIPLET_DICT:
    "Map the triplets starting at idx to a new alphabet."
    # I use 0 for the terminal sentinel and 1 for the 
    # separator, so I start the alphabet at 2, thus the + 2 later.
    # I'm using a dictionary for the alphabet, but you can build
    # it more efficiently by looking at the previous triplet in the
    # sorted SA12. It won't affect the asymptotic running time,
    # though.
    alpha = {}
    for i in idx:
        trip = triplet(x, i)
        if trip not in alpha:
            alpha[trip] = len(alpha) + 2 # +2 to reserve sentinels
    return alpha

In Go:

type triplet = [3]int
type tripletMap = map[triplet]int

func trip(x []int, i int) triplet {
	return triplet{safe_idx(x, i), safe_idx(x, i+1), safe_idx(x, i+2)}
}

func collectAlphabet(x []int, idx []int) tripletMap {
	alpha := tripletMap{}
	for _, i := range idx {
		t := trip(x, i)
		if _, ok := alpha[t]; !ok {
			alpha[t] = len(alpha) + 2
		}
	}
	return alpha
}

After building the new alphabet, we can check its size. If its size is the same as the length of SA12, then all the triplets were unique. If they are unique, and we have sorted the suffixes according to them, then we have completed the sorting. We only had a problem if some suffixes shared prefixes of length three or more, where we might not have them in the right order, but if the triplets were unique, that didn’t happen. So if the alphabet size is the same as the length of SA12, we don’t need to recurse.
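
Since collect_alphabet() hands out one number per distinct triplet, the test amounts to comparing the number of distinct triplets with the length of SA12. A sketch of just that check (needs_recursion() is a made-up name for illustration):

```python
def triplet(x: str, i: int) -> tuple[int, ...]:
    "Letters at i, i+1, i+2 as integers, padded with 0 past the end."
    return tuple(ord(x[k]) if k < len(x) else 0 for k in range(i, i + 3))


def needs_recursion(x: str) -> bool:
    "True if two of the SA12 suffixes start with the same triplet."
    SA12 = [i for i in range(len(x)) if i % 3 != 0]
    # One new letter per distinct triplet, as in collect_alphabet().
    distinct = {triplet(x, i) for i in SA12}
    return len(distinct) < len(SA12)


print(needs_recursion("mississippi"))  # True: (iss) and (ssi) repeat
print(needs_recursion("banana"))       # False: (ana)(nan)(na0)(a00)
```

For banana, the triplets alone already put SA12 in the right order, so skew_rec() would skip the recursion; for mississippi, it has to recurse.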

If we do have duplicated triplets, we need to sort SA12 recursively. We do this by constructing a new string, u, that represents all the suffixes in SA12, but has length only about \(2n/3\).

Conceptually, what we do is we take our string x and then we construct u to be

u = x[1:] # x[2:]

where # is a symbol that doesn’t appear in x (I use 1 for this symbol, which is why my alphabet numbers start at 2).

If x = mississippi then u is ississippi#ssissippi.

That string is longer than x, so of course we can’t recurse on that, but we don’t represent it that way. We take the triplets and consider those single characters, so we have

u = (iss)(iss)(ipp)(i00)#(ssi)(ssi)(ppi)

where the zeros are the padding we get from safe_idx(). The triplets are what we represent with our new alphabet, so they get numbers that match the lexicographical order of the triplets:

(i00) => 2
(ipp) => 3
(iss) => 4
(ppi) => 5
(ssi) => 6

so the real representation is

u = 4432#665

(where # is represented by the number 1)

This string is shorter than x:

x = mississippi
u = 4432#665

Its length is about 2/3 of the length of x (plus one for the sentinel). So, at least the problem size shrinks the way it is supposed to.

Building u is a matter of running through the indices of x twice, extracting the triplets at indices that are one modulo three first and then those at indices that are two modulo three second.

def build_u(x : list[int], alpha : TRIPLET_DICT) -> list[int]:
    "Construct u string, using 1 as central sentinel."
    # By putting the i % 3 == 1 indices first, we know that the central
    # sentinel will always be at len(u) // 2.
    return [*(alpha[triplet(x, i)] for i in range(1, len(x), 3)),
            1,
            *(alpha[triplet(x, i)] for i in range(2, len(x), 3))]

In Go:

func buildU(x []int, alpha tripletMap) []int {
	u := []int{}
	for i := 1; i < len(x); i += 3 {
		u = append(u, alpha[trip(x, i)])
	}
	u = append(u, 1)
	for i := 2; i < len(x); i += 3 {
		u = append(u, alpha[trip(x, i)])
	}
	return u
}

With u encoded as triplets in this way, the first part of u corresponds one-to-one with the suffixes in SA12 that are 1 modulo 3, while the second part corresponds to those that are 2 modulo 3. The # sentinel keeps the strings in the first part separated from those in the second part. It doesn’t appear anywhere else in u, so when a comparison reaches it, the comparison is decided right there, and I never need to read past it into the second part.

The first suffix of u, u[0:] = 4432#665 = (iss)(iss)(ipp)(i00)#…, corresponds to the suffix x[1:]. The next suffix, u[1:] = 432#… = (iss)(ipp)(i00)#…, corresponds to x[4:]. Then I get u[2:] = 32#… = (ipp)(i00)#…, which corresponds to x[7:]. The suffixes in u before the sentinel, i < len(u)//2, are the suffixes in SA12 that are 1 modulo 3, and the suffixes i > len(u)//2 are those that are 2 modulo 3 (because we constructed u that way). With a little bit of arithmetic, we can map indices in u back to indices in x:

def u_idx(i : int, m : int) -> int:
    "Map indices in u back to indices in the original string."
    if i < m: return 1 + 3 * i
    else: return 2 + 3 * (i - m - 1)

In Go:

func uidx(i int, m int) int {
	if i < m {
		return 1 + 3*i
	} else {
		return 2 + 3*(i-m-1)
	}
}

For this arithmetic to work, you have to put the i % 3 == 1 indices first in u, but you could also build u by putting the i % 3 == 2 indices first and then the i % 3 == 1 indices. It is just slightly more involved to find the midpoint in that case.
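
For the curious, here is a sketch of the other order, with the hypothetical names midpoint_21() and u_idx_21(). With the class-2 triplets first, there can be one fewer of them than of the class-1 triplets, so len(u) // 2 can point one past the sentinel, and it is simpler to compute the position from the length of x.

```python
def midpoint_21(n: int) -> int:
    "Sentinel position in u when the i % 3 == 2 triplets come first."
    # u has one letter per index i < n with i % 3 == 2 before the sentinel.
    return len(range(2, n, 3))


def u_idx_21(i: int, m: int) -> int:
    "Map an index in the class-2-first u back to an index in x."
    return 2 + 3 * i if i < m else 1 + 3 * (i - m - 1)


n = len("mississippi")               # 11
m = midpoint_21(n)                   # 3
u_len = m + 1 + len(range(1, n, 3))  # 8, so len(u) // 2 would give 4
print([u_idx_21(i, m) for i in range(u_len) if i != m])
# [2, 5, 8, 1, 4, 7, 10]
```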

The reason that we do the radix sort before we construct u, even though it doesn’t look like we use it for creating u, is to get the order of the new alphabet to match the lexicographical order of the triplets. Because the order is ensured from the radix3() step, we have that if u[i:] < u[j:] then x[u_idx(i,m):] < x[u_idx(j,m):]. That’s the reason we needed this step.
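
We can check that property on mississippi without writing the recursion. The sketch below is self-contained and uses placeholder names: it ranks the distinct triplets with sorted() instead of radix3() and collect_alphabet() (the numbers come out the same, starting at 2), builds u with 1 as the separator, sorts u’s suffixes with a quadratic naive sort in place of the recursive call, and maps the result back with the same arithmetic as u_idx().

```python
def naive_sa(s) -> list[int]:
    "Quadratic reference suffix array; works for both str and list[int]."
    return sorted(range(len(s)), key=lambda i: s[i:])


def triplet(x: list[int], i: int) -> tuple[int, ...]:
    "(x[i], x[i+1], x[i+2]), padded with 0 past the end."
    return tuple(x[k] if k < len(x) else 0 for k in range(i, i + 3))


def u_idx(i: int, m: int) -> int:
    "Map an index in u back to an index in x."
    return 1 + 3 * i if i < m else 2 + 3 * (i - m - 1)


def sa12_via_u(s: str) -> tuple[list[int], list[int]]:
    "Build u for s, sort its suffixes naively and map them back to SA12."
    x = [ord(c) for c in s]
    trips = sorted({triplet(x, i) for i in range(len(x)) if i % 3 != 0})
    alpha = {t: rank + 2 for rank, t in enumerate(trips)}  # 0, 1 reserved
    u = ([alpha[triplet(x, i)] for i in range(1, len(x), 3)] + [1] +
         [alpha[triplet(x, i)] for i in range(2, len(x), 3)])
    m = len(u) // 2  # position of the separator
    return u, [u_idx(i, m) for i in naive_sa(u) if i != m]


u, SA12 = sa12_via_u("mississippi")
print(u)     # [4, 4, 3, 2, 1, 6, 6, 5], i.e. 4432#665
print(SA12)  # [10, 7, 4, 1, 8, 5, 2]
```

The mapped order agrees with the class-1 and class-2 suffixes in the naive suffix array of mississippi, including the pairs 4, 1 and 5, 2 that the triplets alone could not separate.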

We are almost done with the algorithm now. We have built SA12, potentially via the u string and the recursive call, and we have built SA3. So, we have two sorted lists of indices, and we need to merge them into one list of all the suffixes. That is the last step, merge(), in the algorithm, but it isn’t the traditional merge of integer lists: we need the lists merged according to the suffixes the indices represent, of course.

Merging does follow the pattern you see every time you write a normal merge of lists, though. It is just the comparison operation we need to work on.

The Python merge() looks like this:

def merge(x : list[int], SA12 : list[int], SA3 : list[int]) -> list[int]:
    "Merge the suffixes in sorted SA12 and SA3."
    # I'm using a dict here, but you can use a list with a little
    # arithmetic
    ISA = { SA12[i]: i for i in range(len(SA12)) }
    SA = []
    i, j = 0, 0
    while i < len(SA12) and j < len(SA3):
        if less(x, SA12[i], SA3[j], ISA):
            SA.append(SA12[i])
            i += 1
        else:
            SA.append(SA3[j])
            j += 1
    SA.extend(SA12[i:])
    SA.extend(SA3[j:])
    return SA

and the Go version looks like this:

func merge(x []int, SA12 []int, SA3 []int) []int {
	ISA := map[int]int{}
	for i := 0; i < len(SA12); i++ {
		ISA[SA12[i]] = i
	}
	SA := []int{}
	i, j := 0, 0
	for i < len(SA12) && j < len(SA3) {
		if less(x, SA12[i], SA3[j], ISA) {
			SA = append(SA, SA12[i])
			i++
		} else {
			SA = append(SA, SA3[j])
			j++
		}
	}
	SA = append(SA, SA12[i:]...)
	SA = append(SA, SA3[j:]...)
	return SA
}

The comparison is hidden in the less() function, and there is an ISA (inverse suffix array) dict/map that we use for less(). The ISA table is simple enough. It tells us for each element in SA12 at which position it sits. It is just the inverse of the array. I’ve used a dict/map, but you can easily put it in an array if you want. I did in my C code, of course.
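
Here is one way the array version could look, as a sketch with made-up names. One possible "little arithmetic" is a mapping from the class-1 and class-2 indices 1, 2, 4, 5, 7, 8, … to the consecutive slots 0, 1, 2, 3, 4, 5, …, so the table needs exactly len(SA12) entries.

```python
def sa12_slot(i: int) -> int:
    "Slot of index i (with i % 3 != 0) in an array of len(SA12) entries."
    # 1, 2, 4, 5, 7, 8, ... map to 0, 1, 2, 3, 4, 5, ...
    return 2 * (i // 3) + (i % 3 - 1)


def build_isa(SA12: list[int]) -> list[int]:
    "Inverse of SA12 as a list: ISA[sa12_slot(i)] is the rank of i."
    ISA = [0] * len(SA12)
    for rank, i in enumerate(SA12):
        ISA[sa12_slot(i)] = rank
    return ISA


SA12 = [10, 7, 4, 1, 8, 5, 2]  # sorted SA12 for mississippi
ISA = build_isa(SA12)
print(ISA)  # [3, 6, 2, 5, 1, 4, 0]
```

In less(), ISA[i] < ISA[j] would then become ISA[sa12_slot(i)] < ISA[sa12_slot(j)]. A simpler alternative is a list of length len(x) indexed directly by i, which wastes a third of the entries but needs no arithmetic.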

The less() function must compare two suffixes to determine which is smaller, and it has to do it in constant time if the merge() function is to run in linear time (and if it doesn’t, then skew() doesn’t either). That means that we have to exploit the information we have in SA12 and SA3.

Here’s the rule: if we compare i and j and x[i] != x[j], we can immediately determine which is smaller. If x[i] == x[j], however, we need help. If i % 3 != 0 and j % 3 != 0, then both i and j are in SA12, and ISA[i] < ISA[j] tells me their order. If one of i or j is in SA3 and the other is in SA12, however, I cannot use ISA directly. But if i % 3 == 0 and j % 3 == 1, then (i + 1) % 3 == 1 and (j + 1) % 3 == 2, both of those are in SA12, and a recursive call to less() will determine the comparison immediately. Finally, if i % 3 == 0 and j % 3 == 2, then (i + 1) % 3 == 1 and (j + 1) % 3 == 0, which switches the arrays, so it doesn’t let me use ISA yet. However, one more recursive call gives (i + 2) % 3 == 2 and (j + 2) % 3 == 1, which puts both indices in SA12.

It is because of the potential switch from i in SA3 and j in SA12 to i in SA12 and j in SA12 that we split the data into two sets of unequal sizes. If we split the data into those at even and those at odd indices, for example, the less() comparison would keep switching: one index is always even and the other odd, so both would never end up in the same array. With the split modulo 3, where SA12 holds two of the three residue classes, shifting both indices by at most two positions puts them both in SA12. The smallest split in uneven sizes is two-to-one, and that is what the algorithm uses.
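
A small sketch of that argument, looking only at residues. The made-up helper steps_to_sample() counts how many times less() would shift both indices by one before both land in the sampled classes, assuming the letters keep matching, and returns None if that never happens.

```python
from typing import Optional


def steps_to_sample(i: int, j: int, modulus: int,
                    sample: set[int]) -> Optional[int]:
    "Shifts needed before i + k and j + k are both in the sampled classes."
    # Residues repeat with period `modulus`, so checking k < modulus is enough.
    for k in range(modulus):
        if (i + k) % modulus in sample and (j + k) % modulus in sample:
            return k
    return None


# Skew: modulo 3 with classes 1 and 2 sampled; never more than two shifts.
print(max(steps_to_sample(i, j, 3, {1, 2})
          for i in range(3) for j in range(3)))  # 2
# Even/odd split: one index is always even and the other odd.
print(steps_to_sample(0, 1, 2, {1}))  # None
```

Sampling only one of the three classes fails in the same way (steps_to_sample(0, 2, 3, {1}) is also None), which is another way to see why SA12 takes two of the three classes.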

The less() function looks like this:

def less(x : list[int], i : int, j : int, ISA : dict[int,int]) -> bool:
    "Check if x[i:] < x[j:] using the inverse suffix array for SA12."
    a, b = safe_idx(x, i), safe_idx(x, j)
    if a < b: return True
    if a > b: return False
    if i % 3 != 0 and j % 3 != 0: return ISA[i] < ISA[j]
    return less(x, i + 1, j + 1, ISA)

or, in Go, like this:

func less(x []int, i int, j int, ISA map[int]int) bool {
	a, b := safe_idx(x, i), safe_idx(x, j)
	if a < b {
		return true
	}
	if a > b {
		return false
	}
	if i%3 != 0 && j%3 != 0 {
		return ISA[i] < ISA[j]
	}
	return less(x, i+1, j+1, ISA)
}

I might have to take down this post later, if I ask my students to implement this algorithm… but for now, for your reading pleasure, here it is.