Genetic Alignment

In this section, you will learn about the following components in Spatial:

  • FSM

  • Branching

  • FIFO

  • File IO and text management

Note that a large collection of Spatial applications can be found here.

Overview

The Needleman-Wunsch (NW) algorithm is an algorithm used in bioinformatics to align protein or nucleotide sequences. It builds a scoring matrix based on two strings, and then backtraces through the score matrix to determine the alignment of minimum error. For more information on the algorithm’s history and implementations, visit the Wikipedia page (https://en.wikipedia.org/wiki/Needleman-Wunsch_algorithm). The image to the right (credit Wikipedia) demonstrates a rough overview of how the algorithm works.

The Smith-Waterman (SW) algorithm is an alternative to NW with slightly different scoring mechanisms. We will also discuss how to modify NW to create SW.

nw.png
 

Basic implementation (NW)

import spatial.dsl._

@spatial object NW_alg extends SpatialApp {

  /*

   Needleman-Wunsch Genetic Alignment algorithm

   The host passes two sequences (args(0) and args(1)) to the accelerator. The accelerator
   fills a matrix of (score, ptr) tuples, then an FSM follows the ptr entries back from the
   bottom-right cell and enqueues the aligned characters of both sequences into FIFOs, which
   are finally stored to DRAM and checked on the host.

     LETTER KEY:         Scores                   Ptrs
       a = 0                   T  T  C  G                T  T  C  G
       c = 1                0 -1 -2 -3 -4 ...         0  ←  ←  ←  ← ...
       g = 2             T -1  1  0 -1 -2          T  ↑  ↖  ←  ←  ←
       t = 3             C -2  0 -1  1  0          C  ↑  ↑  ↑  ↖  ←
       - = 4             G -3 -2 -2  0  2          G  ↑  ↑  ↑  ↑  ↖
       _ = 5             A -4 -3 -3 -1  1          A  ↑  ↑  ↑  ↑  ↖
                            .                         .
                            .                         .
                            .                         .

     NOTE(review): the code below converts characters with .to[Int8] rather than using the
     0-5 codes listed in the letter key; the key appears to be illustrative only.

     PTR KEY:
       ← = 0 = skipB
       ↑ = 1 = skipA
       ↖ = 2 = align

  */

  // Create struct to hold score and ptr values
  // (one entry of the score matrix: the cell's score and the PTR KEY direction it came from)
  @struct case class nw_tuple(score: Int16, ptr: Int16)

  /** Host entry point.
    *
    * @param args args(0) is sequence A and args(1) is sequence B. The length sent to the
    *             accelerator is taken from args(0) only; sequence B is presumably expected to
    *             have the same length (not checked here -- TODO confirm).
    */
  def main(args: Array[String]): Unit = {

    // Convert chars to 8 bit integers
    // NOTE(review): a, c, g and t are not referenced again in this object (the later `c` and
    // `a` identifiers are lambda parameters that shadow these); only d and underscore are used.
    val a = 'a'.to[Int8]
    val c = 'c'.to[Int8]
    val g = 'g'.to[Int8]
    val t = 't'.to[Int8]
    val d = '-'.to[Int8]
    val underscore = '_'.to[Int8]
    val dash = ArgIn[Int8]

    // Expose dash to FPGA so it can use this value
    setArg(dash,d)

    // Set parallelization
    val par_load = 16
    val par_store = 16
    // presumably a tunable parameter: value 2 with an explored range of 1 to 8 in steps of 1
    // -- TODO confirm against the Spatial parameter-range syntax
    val row_par = 2 (1 -> 1 -> 8)

    // Set up semantics for algorithm values
    // SKIPB / SKIPA / ALIGN are the ptr codes from the PTR KEY above.
    val SKIPB = 0
    val SKIPA = 1
    val ALIGN = 2
    val MATCH_SCORE = 1
    val MISMATCH_SCORE = -1
    val GAP_SCORE = -1 

    // Extract sequences from command line
    val seqa_string = args(0) //"tcgacgaaataggatgacagcacgttctcgtattagagggccgcggtacaaaccaaatgctgcggcgtacagggcacggggcgctgttcgggagatcgggggaatcgtggcgtgggtgattcgccggc"
    val seqb_string = args(1) //"ttcgagggcgcgtgtcgcggtccatcgacatgcccggtcggtgggacgtgggcgcctgatatagaggaatgcgattggaaggtcggacgggtcggcgagttgggcccggtgaatctgccatggtcgat"

    // Pass dimensions to FPGA
    // lengthx2 sizes the aligned-output DRAMs (twice the input length).
    val measured_length = seqa_string.length
    val length = ArgIn[Int]
    val lengthx2 = ArgIn[Int]
    setArg(length, measured_length)
    setArg(lengthx2, measured_length*2)

    // Allocate maximum size so FPGA can be statically synthesized
    val max_length = 512
    assert(max_length >= length.value, "Cannot have string longer than 512 elements")

    // Convert strings to int arrays
    val seqa_bin = seqa_string.map{c => c.to[Int8] }
    val seqb_bin = seqb_string.map{c => c.to[Int8] }
    val seqa_dram_raw = DRAM[Int8](length)
    val seqb_dram_raw = DRAM[Int8](length)
    val seqa_dram_aligned = DRAM[Int8](lengthx2)
    val seqb_dram_aligned = DRAM[Int8](lengthx2)
    setMem(seqa_dram_raw, seqa_bin)
    setMem(seqb_dram_raw, seqb_bin)

    Accel{
      // Create memories for raw data and aligned data
      // The on-chip memories are sized by max_length, not by the runtime length.
      val seqa_sram_raw = SRAM[Int8](max_length)
      val seqb_sram_raw = SRAM[Int8](max_length)
      val seqa_fifo_aligned = FIFO[Int8](max_length*2)
      val seqb_fifo_aligned = FIFO[Int8](max_length*2)

      // Load raw data
      seqa_sram_raw load seqa_dram_raw(0::length par par_load)
      seqb_sram_raw load seqb_dram_raw(0::length par par_load)

      // Allocate score matrix
      // One extra row and column hold the boundary cells (r == 0 and c == 0).
      val score_matrix = SRAM[nw_tuple](max_length+1,max_length+1)

      // Build score matrix
      // Rows index sequence B (seqb_sram_raw(r-1)); columns index sequence A (seqa_sram_raw(c-1)).
      Foreach(length+1 by 1 par row_par){ r =>

        // If running multiple rows in parallel, ensure a later row does not scan an element until previous row has populated it
        val this_body = r % row_par

        // Compute cost for each element in row
        // The column counter starts at -this_body; iterations with c < 0 only delay this row,
        // because the write at the bottom of the loop body is skipped for them.
        Sequential.Foreach(-this_body until length+1 by 1) { c =>
          // Holds the tuple computed in the previous iteration, i.e. the left neighbor.
          val previous_result = Reg[nw_tuple]
          // Boundary cells: row 0 gets score -c with ptr 0 (skipB), column 0 gets score -r
          // with ptr 1 (skipA). All other cells take the best of left, top and diagonal.
          val update = if (r == 0) (nw_tuple(-c.as[Int16], 0)) else if (c == 0) (nw_tuple(-r.as[Int16], 1)) else {
            val match_score = mux(seqa_sram_raw(c-1) == seqb_sram_raw(r-1), MATCH_SCORE.to[Int16], MISMATCH_SCORE.to[Int16])
            val from_top = score_matrix(r-1, c).score + GAP_SCORE
            val from_left = previous_result.score + GAP_SCORE
            val from_diag = score_matrix(r-1, c-1).score + match_score
            // On ties, left (SKIPB) wins over top (SKIPA), which wins over diagonal (ALIGN).
            mux(from_left >= from_top && from_left >= from_diag, nw_tuple(from_left, SKIPB), mux(from_top >= from_diag, nw_tuple(from_top,SKIPA), nw_tuple(from_diag, ALIGN)))
          }
          previous_result := update

          // Predicated write to keep update in bounds
          if (c >= 0) {score_matrix(r,c) = update}
        }
      }

      // Prepare to read score matrix
      // Backtracking starts at the bottom-right cell (length, length).
      val b_addr = Reg[Int](0)
      val a_addr = Reg[Int](0)
      Parallel{b_addr := length; a_addr := length}
      // Set once a pointer that is about to be decremented was 1; gates all later traverse enqueues.
      val done_backtrack = Reg[Bit](false)

      // FSM state definitions
      val traverseState = 0
      val padBothState = 1
      val doneState = 2

      // Starts in state 0 (traverseState) and runs while the state is not doneState.
      // First block: action for the current state. Second block: next-state function.
      FSM(0)(state => state != doneState) { state =>
        if (state == traverseState) {
          // Take from A and B
          if (score_matrix(b_addr,a_addr).ptr == ALIGN.to[Int16]) {
            seqa_fifo_aligned.enq(seqa_sram_raw(a_addr-1), !done_backtrack)
            seqb_fifo_aligned.enq(seqb_sram_raw(b_addr-1), !done_backtrack)
            done_backtrack := b_addr == 1.to[Int] || a_addr == 1.to[Int]
            b_addr :-= 1
            a_addr :-= 1
          // Take from B, skip A
          } else if (score_matrix(b_addr,a_addr).ptr == SKIPA.to[Int16]) {
            seqb_fifo_aligned.enq(seqb_sram_raw(b_addr-1), !done_backtrack)  
            seqa_fifo_aligned.enq(dash, !done_backtrack)          
            done_backtrack := b_addr == 1.to[Int]
            b_addr :-= 1
          // Take from A, skip B
          } else {
            seqa_fifo_aligned.enq(seqa_sram_raw(a_addr-1), !done_backtrack)
            seqb_fifo_aligned.enq(dash, !done_backtrack)          
            done_backtrack := a_addr == 1.to[Int]
            a_addr :-= 1
          }
        // Pad the rest of the result FIFOs
        } else if (state == padBothState) {
          seqa_fifo_aligned.enq(underscore, !seqa_fifo_aligned.isFull) // I think this FSM body either needs to be wrapped in a body or last enq needs to be masked or else we are full before FSM sees full
          seqb_fifo_aligned.enq(underscore, !seqb_fifo_aligned.isFull)
        } else {}
      } { state => 
        // Determine next state based on a and b pointers
        // traverse -> padBoth as soon as either pointer is 0; otherwise -> done once either
        // FIFO reports full; otherwise stay in the current state.
        mux(state == traverseState && ((b_addr == 0.to[Int]) || (a_addr == 0.to[Int])), padBothState, 
          mux(seqa_fifo_aligned.isFull || seqb_fifo_aligned.isFull, doneState, state))
      }

      // Store result
      // Each FIFO is written out to the first length*2 elements of its DRAM.
      Parallel{
        Sequential{seqa_dram_aligned(0::length*2 par par_store) store seqa_fifo_aligned}
        Sequential{seqb_dram_aligned(0::length*2 par par_store) store seqb_fifo_aligned}
      }

    }

    // Inspect results
    val seqa_aligned_result = getMem(seqa_dram_aligned)
    val seqb_aligned_result = getMem(seqb_dram_aligned)
    val seqa_aligned_string = charArrayToString(seqa_aligned_result.map(_.to[U8]))
    val seqb_aligned_string = charArrayToString(seqb_aligned_result.map(_.to[U8]))

    // Assume algorithm worked if >75% match
    // A position counts as a match when both outputs are equal or either one is the dash
    // character; the threshold is 75% of the 2*measured_length output positions.
    val matches = seqa_aligned_result.zip(seqb_aligned_result){(a,b) => if ((a == b) || (a == d) || (b == d)) 1 else 0}.reduce{_+_}
    val cksum = matches.to[Float] > 0.75.to[Float]*measured_length.to[Float]*2

    println("Result A: " + seqa_aligned_string)
    println("Result B: " + seqb_aligned_string)

    println("Found " + matches + " matches out of " + measured_length*2 + " elements")
    println("PASS: " + cksum + " (NW)")
    assert(cksum)
  }
}

This example shows how to write the NW algorithm in Spatial.

Here is a rough sketch of what the scores and pointers mean in this context.

Create a new struct type, which can be used the same way as any type.

We must specify the max string length for the app.

Here we support parallelization along the rows. Because each entry in the score matrix requires input from its immediate neighbors from the left, top, and top-left elements in the matrix, we must ensure that when two rows are running in parallel, the lower row does not scan an element until the upper row has already populated its immediate top neighbor. This is why the counter bound starts at -this_body. We also do not want to start computing an element until its neighbor to the left is complete, so we add the Sequential annotation to this loop. The Spatial compiler will recognize this kind of loop carry dependency and will set the initiation interval properly, but it is explicitly annotated here to demonstrate the concept.

Here is an example of an FSM, which runs for a data-dependent, statically unknown number of iterations. In each iteration of the traverse state, we either move up (decrement b_addr), left (decrement a_addr), or diagonally (decrement both). The direction we take depends on the pointer stored at the current element of the score_matrix. As soon as either addr pointer reaches 0, the FSM leaves the traverse state and pads both result FIFOs with underscores; it finishes once a FIFO is full.

This algorithm does not guarantee a perfect match between two strings; it just searches for the best alignment between the two, which could include mismatches at certain characters. Therefore, we call the run a success if more than 75% of the aligned output positions match (a position also counts as a match when either sequence has a dash there).

To compile the app for a particular target, see the Targets page.

Basic implementation (SW)

@spatial object SW_alg extends SpatialApp { // Name SW conflicts with something in spade

  /*

   Smith-Waterman Genetic Alignment algorithm

   This is just like the NW algorithm, except negative scores are capped at 0, backwards
      traversal starts at the highest score from any element on the perimeter (last row or
      last column), and ends when the score is 0

     [SIC] SW diagram
     NOTE: the diagram below is carried over from the NW example; it still shows negative
     scores, whereas the code below writes max(0, score) into the matrix.

     LETTER KEY:         Scores                   Ptrs
       a = 0                   T  T  C  G                T  T  C  G
       c = 1                0 -1 -2 -3 -4 ...         0  ←  ←  ←  ← ...
       g = 2             T -1  1  0 -1 -2          T  ↑  ↖  ←  ←  ←
       t = 3             C -2  0 -1  1  0          C  ↑  ↑  ↑  ↖  ←
       - = 4             G -3 -2 -2  0  2          G  ↑  ↑  ↑  ↑  ↖
       _ = 5             A -4 -3 -3 -1  1          A  ↑  ↑  ↑  ↑  ↖
                            .                         .
                            .                         .
                            .                         .

     PTR KEY:
       ← = 0 = skipB
       ↑ = 1 = skipA
       ↖ = 2 = align
  */

  // One entry of the score matrix: the cell's score and the PTR KEY direction it came from.
  @struct case class sw_tuple(score: Int16, ptr: Int16)
  // Candidate starting cell for the backwards traversal: its coordinates and its score.
  @struct case class entry_tuple(row: I32, col: I32, score: Int16)


  /** Host entry point.
    *
    * @param args args(0) is sequence A and args(1) is sequence B. The length sent to the
    *             accelerator is taken from args(0) only; sequence B is presumably expected to
    *             have the same length (not checked here -- TODO confirm).
    */
  def main(args: Array[String]): Unit = {

    // Character constants. NOTE(review): a, c, g and t are not referenced again in this object
    // (later `c` and `a` identifiers are lambda parameters that shadow these).
    val a = 'a'.to[Int8]
    val c = 'c'.to[Int8]
    val g = 'g'.to[Int8]
    val t = 't'.to[Int8]
    val d = '-'.to[Int8]
    // dash is passed to the accelerator as an argument so the FSM can enqueue it.
    val dash = ArgIn[Int8]
    setArg(dash,d)
    val underscore = '_'.to[Int8]

    // Parallelization factors for the DRAM loads, DRAM stores and the row loop.
    val par_load = 16
    val par_store = 16
    // presumably a tunable parameter: value 2 with an explored range of 1 to 8 in steps of 1
    // -- TODO confirm against the Spatial parameter-range syntax
    val row_par = 2 (1 -> 1 -> 8)

    // Pointer codes (see PTR KEY) and scoring constants. MATCH_SCORE is 2 here.
    val SKIPB = 0
    val SKIPA = 1
    val ALIGN = 2
    val MATCH_SCORE = 2
    val MISMATCH_SCORE = -1
    val GAP_SCORE = -1
    val seqa_string = args(0)
    val seqb_string = args(1)
    // Pass the sequence length, and twice the length for the output buffers, to the accelerator.
    val measured_length = seqa_string.length
    val length = ArgIn[Int]
    val lengthx2 = ArgIn[Int]
    setArg(length, measured_length)
    setArg(lengthx2, measured_length*2)
    // Static upper bound used to size the on-chip memories.
    val max_length = 512
    assert(max_length >= length.value, "Cannot have string longer than 512 elements")

    // TODO: Support c++ types with 2 bits in dram
    val seqa_bin = seqa_string.map{c => c.to[Int8] }
    val seqb_bin = seqb_string.map{c => c.to[Int8] }

    val seqa_dram_raw = DRAM[Int8](length)
    val seqb_dram_raw = DRAM[Int8](length)
    val seqa_dram_aligned = DRAM[Int8](lengthx2)
    val seqb_dram_aligned = DRAM[Int8](lengthx2)
    setMem(seqa_dram_raw, seqa_bin)
    setMem(seqb_dram_raw, seqb_bin)

    Accel{
      // On-chip copies of the raw sequences and FIFOs collecting the aligned output.
      val seqa_sram_raw = SRAM[Int8](max_length)
      val seqb_sram_raw = SRAM[Int8](max_length)
      val seqa_fifo_aligned = FIFO[Int8](max_length*2)
      val seqb_fifo_aligned = FIFO[Int8](max_length*2)

      seqa_sram_raw load seqa_dram_raw(0::length par par_load)
      seqb_sram_raw load seqb_dram_raw(0::length par par_load)

      // One extra row and column hold the boundary cells (r == 0 and c == 0).
      val score_matrix = SRAM[sw_tuple](max_length+1,max_length+1)

      // Receives the highest-scoring candidate found on the last row / last column.
      val entry_point = Reg[entry_tuple]
      // Build score matrix
      // Each row yields its best perimeter candidate; the reduction keeps the higher score.
      Reduce(entry_point)(length+1 by 1 par row_par){ r =>
        val possible_entry_point = Reg[entry_tuple]
        // Offset used to start the column counter below zero for rows running in parallel.
        val this_body = r % row_par
        Sequential.Foreach(-this_body until length+1 by 1) { c => // Bug #151, should be able to remove previous_result reg when fixed
          // Holds the tuple computed in the previous iteration, i.e. the left neighbor.
          val previous_result = Reg[sw_tuple]
          // Boundary cells score 0 (ptr 0 on row 0, ptr 1 on column 0); all other cells take
          // the best of left, top and diagonal, preferring left, then top, on ties.
          val update = if (r == 0) (sw_tuple(0, 0)) else if (c == 0) (sw_tuple(0, 1)) else {
            val match_score = mux(seqa_sram_raw(c-1) == seqb_sram_raw(r-1), MATCH_SCORE.to[Int16], MISMATCH_SCORE.to[Int16])
            val from_top = score_matrix(r-1, c).score + GAP_SCORE
            val from_left = previous_result.score + GAP_SCORE
            val from_diag = score_matrix(r-1, c-1).score + match_score
            mux(from_left >= from_top && from_left >= from_diag, sw_tuple(from_left, SKIPB), mux(from_top >= from_diag, sw_tuple(from_top,SKIPA), sw_tuple(from_diag, ALIGN)))
          }
          // NOTE(review): previous_result keeps the unclamped score, while score_matrix below
          // stores max(0, score); from_left therefore reads an unclamped value but from_top and
          // from_diag read clamped ones -- confirm this asymmetry is intended.
          previous_result := update
          // Track the best-scoring cell seen so far in the last column or last row.
          if ((c == length.value | r == length.value) && possible_entry_point.score < update.score) possible_entry_point := entry_tuple(r, c, update.score)
          // Skip the write for the negative start-up columns; clamp negative scores to 0.
          if (c >= 0) {score_matrix(r,c) = sw_tuple(max(0, update.score),update.ptr)}
        }
        possible_entry_point
      }{(a,b) => mux(a.score > b.score, a, b)}

      // FSM state definitions
      val traverseState = 0
      val padBothState = 1
      val doneState = 2

      // Backtracking starts at the entry point chosen by the reduction above.
      val b_addr = Reg[Int](0)
      val a_addr = Reg[Int](0)
      Parallel{b_addr := entry_point.row; a_addr := entry_point.col}
      // Set once a pointer that is about to be decremented was 1; gates all later traverse enqueues.
      val done_backtrack = Reg[Bit](false)
      // Starts in state 0 (traverseState) and runs while the state is not doneState.
      // First block: action for the current state. Second block: next-state function.
      FSM(0){state => state != doneState }{ state =>
        if (state == traverseState) {
          // Take from A and B
          if (score_matrix(b_addr,a_addr).ptr == ALIGN.to[Int16]) {
            seqa_fifo_aligned.enq(seqa_sram_raw(a_addr-1), !done_backtrack)
            seqb_fifo_aligned.enq(seqb_sram_raw(b_addr-1), !done_backtrack)
            done_backtrack := b_addr == 1.to[Int] || a_addr == 1.to[Int]
            b_addr :-= 1
            a_addr :-= 1
          // Take from B, skip A
          } else if (score_matrix(b_addr,a_addr).ptr == SKIPA.to[Int16]) {
            seqb_fifo_aligned.enq(seqb_sram_raw(b_addr-1), !done_backtrack)
            seqa_fifo_aligned.enq(dash, !done_backtrack)
            done_backtrack := b_addr == 1.to[Int]
            b_addr :-= 1
          // Take from A, skip B
          } else {
            seqa_fifo_aligned.enq(seqa_sram_raw(a_addr-1), !done_backtrack)
            seqb_fifo_aligned.enq(dash, !done_backtrack)
            done_backtrack := a_addr == 1.to[Int]
            a_addr :-= 1
          }
        // NOTE(review): the next-state function below only ever yields doneState or the current
        // state, so this padding branch is not reached from the initial traverseState.
        } else if (state == padBothState) {
          seqa_fifo_aligned.enq(underscore, !seqa_fifo_aligned.isFull) // I think this FSM body either needs to be wrapped in a body or last enq needs to be masked or else we are full before FSM sees full
          seqb_fifo_aligned.enq(underscore, !seqb_fifo_aligned.isFull)
        } else {}
      } { state =>
        // Stop traversing once the cell at the current pointers has score 0; otherwise stay.
        mux(state == traverseState && (score_matrix(b_addr,a_addr).score == 0.to[Int16]), doneState, state)
      }

      // Store result
      // Only as many elements as each FIFO currently holds (numel) are written to DRAM.
      Parallel{
        Sequential{seqa_dram_aligned(0::seqa_fifo_aligned.numel par par_store) store seqa_fifo_aligned}
        Sequential{seqb_dram_aligned(0::seqb_fifo_aligned.numel par par_store) store seqb_fifo_aligned}
      }

    }

    // Inspect results
    val seqa_aligned_result = getMem(seqa_dram_aligned)
    val seqb_aligned_result = getMem(seqb_dram_aligned)
    val seqa_aligned_string = charArrayToString(seqa_aligned_result.map(_.to[U8]))
    val seqb_aligned_string = charArrayToString(seqb_aligned_result.map(_.to[U8]))

    // Pass if >75% match
    // A position counts as a match when both outputs are equal or either one is the dash
    // character; the threshold is 75% of the 2*measured_length DRAM positions.
    // NOTE(review): the accelerator stores only numel elements, so the remaining positions
    // compared here were never written by it -- confirm what they are expected to contain.
    val matches = seqa_aligned_result.zip(seqb_aligned_result){(a,b) => if ((a == b) || (a == d) || (b == d)) 1 else 0}.reduce{_+_}
    val cksum = matches.to[Float] > 0.75.to[Float]*measured_length.to[Float]*2

    println("Result A: " + seqa_aligned_string)
    println("Result B: " + seqb_aligned_string)
    println("Found " + matches + " matches out of " + measured_length*2 + " elements")
    println("PASS: " + cksum + " (SW)")
    assert(cksum)
  }
}

Here is an implementation of SW. In this variation, we do not allow negative scores and clip these values to 0. We also allow the decoder FSM to start at any point along the right and bottom edges of the score matrix, rather than forcing it to begin at the bottom-right corner.