Genetic Alignment
In this section, you will learn about the following components in Spatial:
FSM
Branching
FIFO
File IO and text management
Note that a large collection of Spatial applications can be found here.
Overview
The Needleman-Wunsch (NW) algorithm is an algorithm used in bioinformatics to align protein or nucleotide sequences. It builds a scoring matrix from two strings and then backtraces through the score matrix to recover the highest-scoring (i.e., best) alignment. For more information on the algorithm’s history and implementations, visit the Wikipedia page (https://en.wikipedia.org/wiki/Needleman-Wunsch_algorithm). The image to the right (credit Wikipedia) gives a rough overview of how the algorithm works.
The Smith-Waterman (SW) algorithm is an alternative to NW with slightly different scoring mechanisms. We will also discuss how to modify NW to create SW.
Basic implementation (NW)
import spatial.dsl._

// Needleman-Wunsch (NW) alignment of two character sequences given as args(0) and args(1).
// The host code encodes both strings as Int8 arrays and copies them to DRAM. The Accel block
// builds the score/pointer matrix, backtraces through it with an FSM and writes the two
// aligned sequences back to DRAM. The host then prints them and checks how many positions agree.
@spatial object NW_alg extends SpatialApp {
  /*
  Needleman-Wunsch Genetic Alignment algorithm

    LETTER KEY:         Scores                   Ptrs
      a = 0                   T  T  C  G                T  T  C  G
      c = 1                0 -1 -2 -3 -4 ...         0  ←  ←  ←  ← ...
      g = 2             T -1  1  0 -1 -2          T  ↑  ↖  ←  ←  ←
      t = 3             C -2  0 -1  1  0          C  ↑  ↑  ↑  ↖  ←
      - = 4             G -3 -2 -2  0  2          G  ↑  ↑  ↑  ↑  ↖
      _ = 5             A -4 -3 -3 -1  1          A  ↑  ↑  ↑  ↑  ↖
                           .                         .
                           .                         .
                           .                         .

    PTR KEY:
      ← = 0 = skipB
      ↑ = 1 = skipA
      ↖ = 2 = align

  NOTE(review): the LETTER KEY looks like a leftover from a compact encoding. The code below
  stores raw ASCII codes ('a'.to[Int8], '-'.to[Int8], ...) -- confirm before relying on it.
  */

  // Create struct to hold score and ptr values
  @struct case class nw_tuple(score: Int16, ptr: Int16)

  def main(args: Array[String]): Unit = {

    // Convert chars to 8 bit integers
    // (a, g and t are not referenced below; the names `a` and `c` are later shadowed by
    //  lambda/loop parameters)
    val a = 'a'.to[Int8]
    val c = 'c'.to[Int8]
    val g = 'g'.to[Int8]
    val t = 't'.to[Int8]
    // Gap character emitted when one sequence is skipped
    val d = '-'.to[Int8]
    // Padding character used to fill the result FIFOs after the backtrace
    val underscore = '_'.to[Int8]
    val dash = ArgIn[Int8]
    // Expose dash to FPGA so it can use this value
    setArg(dash,d)

    // Set parallelization
    val par_load = 16
    val par_store = 16
    // Default of 2 rows in parallel. (1 -> 1 -> 8) appears to be a design-space range
    // (min -> step -> max) for exploration -- NOTE(review): confirm.
    val row_par = 2 (1 -> 1 -> 8)

    // Set up semantics for algorithm values
    val SKIPB = 0
    val SKIPA = 1
    val ALIGN = 2
    val MATCH_SCORE = 1
    val MISMATCH_SCORE = -1
    val GAP_SCORE = -1

    // Extract sequences from command line
    val seqa_string = args(0) //"tcgacgaaataggatgacagcacgttctcgtattagagggccgcggtacaaaccaaatgctgcggcgtacagggcacggggcgctgttcgggagatcgggggaatcgtggcgtgggtgattcgccggc"
    val seqb_string = args(1) //"ttcgagggcgcgtgtcgcggtccatcgacatgcccggtcggtgggacgtgggcgcctgatatagaggaatgcgattggaaggtcggacgggtcggcgagttgggcccggtgaatctgccatggtcgat"

    // Pass dimensions to FPGA
    // NOTE(review): every buffer is sized from seqa's length, so seqb is assumed to be the
    // same length.
    val measured_length = seqa_string.length
    val length = ArgIn[Int]
    val lengthx2 = ArgIn[Int]
    setArg(length, measured_length)
    setArg(lengthx2, measured_length*2)

    // Allocate maximum size so FPGA can be statically synthesized
    val max_length = 512
    assert(max_length >= length.value, "Cannot have string longer than 512 elements")

    // Convert strings to int arrays
    val seqa_bin = seqa_string.map{c => c.to[Int8] }
    val seqb_bin = seqb_string.map{c => c.to[Int8] }
    val seqa_dram_raw = DRAM[Int8](length)
    val seqb_dram_raw = DRAM[Int8](length)
    // Aligned outputs (with gaps/padding) are stored in buffers of twice the input length
    val seqa_dram_aligned = DRAM[Int8](lengthx2)
    val seqb_dram_aligned = DRAM[Int8](lengthx2)
    setMem(seqa_dram_raw, seqa_bin)
    setMem(seqb_dram_raw, seqb_bin)

    Accel{
      // Create memories for raw data and aligned data
      val seqa_sram_raw = SRAM[Int8](max_length)
      val seqb_sram_raw = SRAM[Int8](max_length)
      val seqa_fifo_aligned = FIFO[Int8](max_length*2)
      val seqb_fifo_aligned = FIFO[Int8](max_length*2)

      // Load raw data
      seqa_sram_raw load seqa_dram_raw(0::length par par_load)
      seqb_sram_raw load seqb_dram_raw(0::length par par_load)

      // Allocate score matrix.
      // Row r pairs with seqb(r-1) and column c with seqa(c-1); row 0 and column 0 are
      // the gap-only boundary.
      val score_matrix = SRAM[nw_tuple](max_length+1,max_length+1)

      // Build score matrix
      Foreach(length+1 by 1 par row_par){ r =>
        // If running multiple rows in parallel, ensure a later row does not scan an element
        // until previous row has populated it.
        // The extra iterations with c < 0 delay this row by `this_body` steps relative to the
        // first row of its parallel group.
        val this_body = r % row_par

        // Compute cost for each element in row
        Sequential.Foreach(-this_body until length+1 by 1) { c =>
          // Carries this row's previous-column result, i.e. the left neighbour (r, c-1)
          val previous_result = Reg[nw_tuple]

          // Row 0 / column 0 hold the gap-penalty boundary (-c / -r). Interior cells take the
          // best of left, top and diagonal; ties prefer left (SKIPB), then top (SKIPA), then
          // diagonal (ALIGN).
          // For c < 0 the interior branch reads out-of-range indices. That result is
          // discarded by the c >= 0 write predicate below, and previous_result is
          // overwritten again at c == 0.
          val update = if (r == 0) (nw_tuple(-c.as[Int16], 0)) else if (c == 0) (nw_tuple(-r.as[Int16], 1)) else {
            val match_score = mux(seqa_sram_raw(c-1) == seqb_sram_raw(r-1), MATCH_SCORE.to[Int16], MISMATCH_SCORE.to[Int16])
            val from_top = score_matrix(r-1, c).score + GAP_SCORE
            val from_left = previous_result.score + GAP_SCORE
            val from_diag = score_matrix(r-1, c-1).score + match_score
            mux(from_left >= from_top && from_left >= from_diag, nw_tuple(from_left, SKIPB), mux(from_top >= from_diag, nw_tuple(from_top,SKIPA), nw_tuple(from_diag, ALIGN)))
          }
          previous_result := update

          // Predicated write to keep update in bounds
          if (c >= 0) {score_matrix(r,c) = update}
        }
      }

      // Prepare to read score matrix: the backtrace starts at the bottom-right cell (length, length)
      val b_addr = Reg[Int](0)
      val a_addr = Reg[Int](0)
      Parallel{b_addr := length; a_addr := length}
      // Set once the character at index 0 of a sequence has been consumed; it masks any further
      // traverse-state enqueues
      val done_backtrack = Reg[Bit](false)

      // FSM state definitions
      //   traverseState: follow ptrs toward the origin, enqueuing one aligned column per step.
      //                  Characters are taken from the end of the sequences backward, so the
      //                  results come out reversed.
      //   padBothState:  enqueue '_' into both FIFOs (masked once a FIFO is full)
      //   doneState:     exit
      val traverseState = 0
      val padBothState = 1
      val doneState = 2
      FSM(0)(state => state != doneState) { state =>
        if (state == traverseState) {
          // Take from A and B
          if (score_matrix(b_addr,a_addr).ptr == ALIGN.to[Int16]) {
            seqa_fifo_aligned.enq(seqa_sram_raw(a_addr-1), !done_backtrack)
            seqb_fifo_aligned.enq(seqb_sram_raw(b_addr-1), !done_backtrack)
            done_backtrack := b_addr == 1.to[Int] || a_addr == 1.to[Int]
            b_addr :-= 1
            a_addr :-= 1
          // Take from B, skip A
          } else if (score_matrix(b_addr,a_addr).ptr == SKIPA.to[Int16]) {
            seqb_fifo_aligned.enq(seqb_sram_raw(b_addr-1), !done_backtrack)
            seqa_fifo_aligned.enq(dash, !done_backtrack)
            done_backtrack := b_addr == 1.to[Int]
            b_addr :-= 1
          // Take from A, skip B
          } else {
            seqa_fifo_aligned.enq(seqa_sram_raw(a_addr-1), !done_backtrack)
            seqb_fifo_aligned.enq(dash, !done_backtrack)
            done_backtrack := a_addr == 1.to[Int]
            a_addr :-= 1
          }
        // Pad the rest of the result FIFOs
        } else if (state == padBothState) {
          seqa_fifo_aligned.enq(underscore, !seqa_fifo_aligned.isFull) // I think this FSM body either needs to be wrapped in a body or last enq needs to be masked or else we are full before FSM sees full
          seqb_fifo_aligned.enq(underscore, !seqb_fifo_aligned.isFull)
        } else {}
      } { state =>
        // Determine next state based on a and b pointers.
        // Leave traverse as soon as EITHER pointer reaches 0; any remaining prefix of the
        // other sequence is not emitted. Otherwise finish once either FIFO is full.
        mux(state == traverseState && ((b_addr == 0.to[Int]) || (a_addr == 0.to[Int])), padBothState, mux(seqa_fifo_aligned.isFull || seqb_fifo_aligned.isFull, doneState, state))
      }

      // Store result (2 * length elements from each FIFO)
      Parallel{
        Sequential{seqa_dram_aligned(0::length*2 par par_store) store seqa_fifo_aligned}
        Sequential{seqb_dram_aligned(0::length*2 par par_store) store seqb_fifo_aligned}
      }
    }

    // Inspect results
    val seqa_aligned_result = getMem(seqa_dram_aligned)
    val seqb_aligned_result = getMem(seqb_dram_aligned)
    val seqa_aligned_string = charArrayToString(seqa_aligned_result.map(_.to[U8]))
    val seqb_aligned_string = charArrayToString(seqb_aligned_result.map(_.to[U8]))

    // Assume algorithm worked if >75% match.
    // A position counts when both characters are equal (including matching '_' padding) or
    // either side is a '-' gap.
    val matches = seqa_aligned_result.zip(seqb_aligned_result){(a,b) => if ((a == b) || (a == d) || (b == d)) 1 else 0}.reduce{_+_}
    val cksum = matches.to[Float] > 0.75.to[Float]*measured_length.to[Float]*2
    println("Result A: " + seqa_aligned_string)
    println("Result B: " + seqb_aligned_string)
    println("Found " + matches + " matches out of " + measured_length*2 + " elements")
    println("PASS: " + cksum + " (NW)")
    assert(cksum)
  }
}
This example shows how to write the NW algorithm in Spatial.
Here is a rough sketch of what the scores and pointers mean in this context.
The `@struct` annotation creates a new struct type (`nw_tuple`, holding a score and a pointer), which can be used the same way as any other type.
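We must specify a maximum string length (`max_length`) so that the on-chip memories can be sized statically; longer inputs are rejected by an assertion.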
Here we support parallelization along the rows. Each entry in the score matrix depends on its immediate left, top, and top-left neighbors. When two rows run in parallel, we must therefore ensure that the lower row does not scan an element until the upper row has already populated its top neighbor. This is why the column counter starts at `-this_body`: the extra iterations with negative column indices delay the lower row, and the write to `score_matrix` is predicated on `c >= 0` so those iterations have no effect. We also do not want to start computing an element until its left neighbor is complete, so we add the `Sequential` annotation to this loop. The Spatial compiler recognizes this kind of loop-carried dependency and sets the initiation interval accordingly, but it is annotated explicitly here to demonstrate the concept.
Here is an example of an FSM, which runs for a data-dependent number of iterations that is not known statically. In each iteration of the traverse state, we move up (decrement `b_addr`), left (decrement `a_addr`), or diagonally (decrement both), depending on the pointer stored at the current element of `score_matrix`. Because the backtrace starts at the bottom-right corner, the aligned characters are emitted from the end of the sequences backward. As soon as either address reaches 0, the FSM switches to a padding state. That state fills both result FIFOs with `_` until one of them is full, and then the FSM finishes.
This algorithm does not guarantee a perfect match between the two strings; it only searches for the best alignment between them, which may include mismatches at certain characters. Therefore, we only expect roughly 75% of the aligned positions to match (a position where either side is a gap `-` also counts), and we call that a success.
To compile the app for a particular target, see the Targets page
Basic implementation (SW)
// Smith-Waterman (SW) local alignment of two character sequences given as args(0) and args(1).
// Same structure as the NW app, with three differences:
//   - stored scores are clamped at 0;
//   - the backtrace starts at the highest score found in the bottom row or right column;
//   - the backtrace stops when it reaches a cell whose score is 0.
@spatial object SW_alg extends SpatialApp { // Name SW conflicts with something in spade
  /*
  Smith-Waterman Genetic Alignment algorithm

  This is just like the NW algorithm, except:
    - negative scores are capped at 0;
    - backwards traversal starts at the highest score in the bottom row or right column;
    - traversal ends when the score is 0.

  NOTE: the score diagram below is the NW one. In SW the stored scores are clamped at 0,
  so no negative values appear in the matrix.

    LETTER KEY:         Scores                   Ptrs
      a = 0                   T  T  C  G                T  T  C  G
      c = 1                0 -1 -2 -3 -4 ...         0  ←  ←  ←  ← ...
      g = 2             T -1  1  0 -1 -2          T  ↑  ↖  ←  ←  ←
      t = 3             C -2  0 -1  1  0          C  ↑  ↑  ↑  ↖  ←
      - = 4             G -3 -2 -2  0  2          G  ↑  ↑  ↑  ↑  ↖
      _ = 5             A -4 -3 -3 -1  1          A  ↑  ↑  ↑  ↑  ↖
                           .                         .
                           .                         .
                           .                         .

    PTR KEY:
      ← = 0 = skipB
      ↑ = 1 = skipA
      ↖ = 2 = align
  */

  // Score and backtrace pointer of one score-matrix cell
  @struct case class sw_tuple(score: Int16, ptr: Int16)
  // Candidate backtrace start: cell coordinates plus its (unclamped) score
  @struct case class entry_tuple(row: I32, col: I32, score: Int16)

  def main(args: Array[String]): Unit = {
    // (a, g and t are not referenced below; `a` and `c` are later shadowed by lambda/loop parameters)
    val a = 'a'.to[Int8]
    val c = 'c'.to[Int8]
    val g = 'g'.to[Int8]
    val t = 't'.to[Int8]
    val d = '-'.to[Int8]
    val dash = ArgIn[Int8]
    setArg(dash,d)
    val underscore = '_'.to[Int8]

    val par_load = 16
    val par_store = 16
    val row_par = 2 (1 -> 1 -> 8)

    val SKIPB = 0
    val SKIPA = 1
    val ALIGN = 2
    val MATCH_SCORE = 2
    val MISMATCH_SCORE = -1
    val GAP_SCORE = -1

    val seqa_string = args(0)
    val seqb_string = args(1)
    // NOTE(review): every buffer is sized from seqa's length, so seqb is assumed to be the same length.
    val measured_length = seqa_string.length
    val length = ArgIn[Int]
    val lengthx2 = ArgIn[Int]
    setArg(length, measured_length)
    setArg(lengthx2, measured_length*2)
    val max_length = 512
    assert(max_length >= length.value, "Cannot have string longer than 512 elements")

    // TODO: Support c++ types with 2 bits in dram
    val seqa_bin = seqa_string.map{c => c.to[Int8] }
    val seqb_bin = seqb_string.map{c => c.to[Int8] }

    val seqa_dram_raw = DRAM[Int8](length)
    val seqb_dram_raw = DRAM[Int8](length)
    val seqa_dram_aligned = DRAM[Int8](lengthx2)
    val seqb_dram_aligned = DRAM[Int8](lengthx2)
    setMem(seqa_dram_raw, seqa_bin)
    setMem(seqb_dram_raw, seqb_bin)

    Accel{
      val seqa_sram_raw = SRAM[Int8](max_length)
      val seqb_sram_raw = SRAM[Int8](max_length)
      val seqa_fifo_aligned = FIFO[Int8](max_length*2)
      val seqb_fifo_aligned = FIFO[Int8](max_length*2)

      seqa_sram_raw load seqa_dram_raw(0::length par par_load)
      seqb_sram_raw load seqb_dram_raw(0::length par par_load)

      // Row r pairs with seqb(r-1) and column c with seqa(c-1)
      val score_matrix = SRAM[sw_tuple](max_length+1,max_length+1)

      // Best-scoring cell in the bottom row or right column; the backtrace starts here
      val entry_point = Reg[entry_tuple]

      // Build score matrix, reducing each row's best edge cell into entry_point
      Reduce(entry_point)(length+1 by 1 par row_par){ r =>
        val possible_entry_point = Reg[entry_tuple]
        // Extra iterations with c < 0 delay this row behind the row above it when rows run in parallel
        val this_body = r % row_par
        Sequential.Foreach(-this_body until length+1 by 1) { c =>
          // Bug #151, should be able to remove previous_result reg when fixed
          // (it stands in for reading the left neighbour score_matrix(r, c-1))
          val previous_result = Reg[sw_tuple]
          val update = if (r == 0) (sw_tuple(0, 0)) else if (c == 0) (sw_tuple(0, 1)) else {
            val match_score = mux(seqa_sram_raw(c-1) == seqb_sram_raw(r-1), MATCH_SCORE.to[Int16], MISMATCH_SCORE.to[Int16])
            val from_top = score_matrix(r-1, c).score + GAP_SCORE
            val from_left = previous_result.score + GAP_SCORE
            val from_diag = score_matrix(r-1, c-1).score + match_score
            mux(from_left >= from_top && from_left >= from_diag, sw_tuple(from_left, SKIPB), mux(from_top >= from_diag, sw_tuple(from_top,SKIPA), sw_tuple(from_diag, ALIGN)))
          }

          // SW clamps negative scores to 0. The left-neighbour register must hold the same
          // clamped value that is written to the SRAM; otherwise from_left would be computed
          // from a negative score while from_top/from_diag use clamped ones.
          val clamped = sw_tuple(max(0, update.score), update.ptr)
          previous_result := clamped

          // Iterations with c < 0 only exist to skew parallel rows. They read out-of-range
          // addresses, so neither the matrix write nor the entry-point candidate may use them.
          val in_bounds = c >= 0
          if (in_bounds && (c == length.value | r == length.value) && possible_entry_point.score < update.score) possible_entry_point := entry_tuple(r, c, update.score)
          if (in_bounds) {score_matrix(r,c) = clamped}
        }
        possible_entry_point
      }{(a,b) => mux(a.score > b.score, a, b)}

      val traverseState = 0
      val padBothState = 1 // not reachable: the next-state function below never selects it
      val doneState = 2

      val b_addr = Reg[Int](0)
      val a_addr = Reg[Int](0)
      Parallel{b_addr := entry_point.row; a_addr := entry_point.col}
      val done_backtrack = Reg[Bit](false)

      FSM(0){state => state != doneState }{ state =>
        if (state == traverseState) {
          if (score_matrix(b_addr,a_addr).ptr == ALIGN.to[Int16]) {
            // Take from A and B
            seqa_fifo_aligned.enq(seqa_sram_raw(a_addr-1), !done_backtrack)
            seqb_fifo_aligned.enq(seqb_sram_raw(b_addr-1), !done_backtrack)
            done_backtrack := b_addr == 1.to[Int] || a_addr == 1.to[Int]
            b_addr :-= 1
            a_addr :-= 1
          } else if (score_matrix(b_addr,a_addr).ptr == SKIPA.to[Int16]) {
            // Take from B, skip A
            seqb_fifo_aligned.enq(seqb_sram_raw(b_addr-1), !done_backtrack)
            seqa_fifo_aligned.enq(dash, !done_backtrack)
            done_backtrack := b_addr == 1.to[Int]
            b_addr :-= 1
          } else {
            // Take from A, skip B
            seqa_fifo_aligned.enq(seqa_sram_raw(a_addr-1), !done_backtrack)
            seqb_fifo_aligned.enq(dash, !done_backtrack)
            done_backtrack := a_addr == 1.to[Int]
            a_addr :-= 1
          }
        } else if (state == padBothState) {
          seqa_fifo_aligned.enq(underscore, !seqa_fifo_aligned.isFull) // I think this FSM body either needs to be wrapped in a body or last enq needs to be masked or else we are full before FSM sees full
          seqb_fifo_aligned.enq(underscore, !seqb_fifo_aligned.isFull)
        } else {}
      } { state =>
        // Local alignment ends at the first cell whose (clamped) score is 0
        mux(state == traverseState && (score_matrix(b_addr,a_addr).score == 0.to[Int16]), doneState, state)
      }

      // Store only the characters actually produced by the backtrace
      Parallel{
        Sequential{seqa_dram_aligned(0::seqa_fifo_aligned.numel par par_store) store seqa_fifo_aligned}
        Sequential{seqb_dram_aligned(0::seqb_fifo_aligned.numel par par_store) store seqb_fifo_aligned}
      }
    }

    val seqa_aligned_result = getMem(seqa_dram_aligned)
    val seqb_aligned_result = getMem(seqb_dram_aligned)
    val seqa_aligned_string = charArrayToString(seqa_aligned_result.map(_.to[U8]))
    val seqb_aligned_string = charArrayToString(seqb_aligned_result.map(_.to[U8]))

    // Pass if >75% match (equal characters, or either side is a '-' gap)
    val matches = seqa_aligned_result.zip(seqb_aligned_result){(a,b) => if ((a == b) || (a == d) || (b == d)) 1 else 0}.reduce{_+_}
    val cksum = matches.to[Float] > 0.75.to[Float]*measured_length.to[Float]*2
    println("Result A: " + seqa_aligned_string)
    println("Result B: " + seqb_aligned_string)
    println("Found " + matches + " matches out of " + measured_length*2 + " elements")
    println("PASS: " + cksum + " (SW)")
    assert(cksum)
  }
}
Here is an implementation of SW. In this variation, we do not allow negative scores and clip them to 0. We also let the decoder FSM start at the highest-scoring element anywhere along the right or bottom edge of the score matrix, rather than forcing it to begin at the bottom-right corner. The backtrace stops as soon as it reaches an element whose score is 0.