前回のChiselの記事では久しぶりにChisel-bootcampの学習に戻りScalaの高階関数についてを学習した。
今日も引き続き高階関数の章でもう少し例題の確認と練習問題に取り組んでいく。
高階関数
mapの練習問題から再開する。
練習問題:map
問題文は以下になる。
// 入力のリストを倍にするように`???`を埋めよう // これは次の値を返すはずだ: List(2, 4, 6, 8) println(List(1, 2, 3, 4).map(???))
解答 - クリックすると開くので、見たくない場合は開かないように注意。
println(List(1, 2, 3, 4).map(_ * 2)
例題:zipWithIndex
zipWithIndex
の定義はzipWithIndex: List[(A, Int)]
となっており、引数を取らないがリストの各要素にインデックスを付与したタプルのリストを返すという処理を行う。
言葉では若干わかりづらいが、実行してみれば一目瞭然だ。・
println(List(1, 2, 3, 4).zipWithIndex) // インデックスは0からスタートする println(List("a", "b", "c", "d").zipWithIndex) // タプルがネストするようなケースも可能 println(List(("a", "b"), ("c", "d"), ("e", "f"), ("g", "h")).zipWithIndex)
実行すると以下のようになる。
List((1,0), (2,1), (3,2), (4,3)) List((a,0), (b,1), (c,2), (d,3)) List(((a,b),0), ((c,d),1), ((e,f),2), ((g,h),3))
pythonのenumerateの処理に似てるといえば似てる。
例題:reduce
これはすでに使っているが、リストの各要素同士で処理を適用するという処理になる。定義はreduce(op: (A, A) ⇒ A): A
。
println(List(1, 2, 3, 4).reduce((a, b) => a + b)) // returns the sum of all the elements println(List(1, 2, 3, 4).reduce(_ * _)) // returns the product of all the elements println(List(1, 2, 3, 4).map(_ + 1).reduce(_ + _)) // you can chain reduce onto the result of a map
結果はListの中身であるInt
型がreduce
内部の処理で計算されてInt
型として返ってくる。
10 24 14
なおreduce
はリストの中身が空の場合は以下のようなエラーが発生する。
java.lang.UnsupportedOperationException: empty.reduceLeft scala.collection.LinearSeqOptimized$class.reduceLeft(LinearSeqOptimized.scala:137) scala.collection.immutable.List.reduceLeft(List.scala:84) scala.collection.TraversableOnce$class.reduce(TraversableOnce.scala:208) scala.collection.AbstractTraversable.reduce(Traversable.scala:104) ammonite.$sess.cmd7$Helper.<init>(cmd7.sc:1) ammonite.$sess.cmd7$.<init>(cmd7.sc:7) ammonite.$sess.cmd7$.<clinit>(cmd7.sc:-1)
練習問題:reduce
問題文はこちら。
// 入力されたリストの各要素を2倍した値同士で積を取るように???を埋めよう // これは次の値を返すはずだ: (1*2)*(2*2)*(3*2)*(4*2) = 384 println(List(1, 2, 3, 4).map(???).reduce(???))
解答 - クリックすると開くので、見たくない場合は開かないように注意。
println(List(1, 2, 3, 4).map(_ * 2).reduce(_ * _))
例題:fold
fold
はreduce
と似ている処理だ。定義はfold(z: A)(op: (A, A) ⇒ A): A
となっておりreduce
との違いはop
の前にz
が追加されていることだ。このz
はリストの最初に追加される要素となる。言い換えると初期値付きのrecude
と言ったところか。
println(List(1, 2, 3, 4).fold(0)(_ + _)) // equivalent to the sum using reduce println(List(1, 2, 3, 4).fold(1)(_ + _)) // like above, but accumulation starts at 1 println(List().fold(1)(_ + _)) // unlike reduce, does not fail on an empty input
実行するとすべてただのInt
が返却される。
10 11 1
1つ目と2つ目はz
に入れる値が異なるため結果が1
だけずれているのがわかると思う。またreduce
とは異なり初期値を持たせることができるため、3つ目の結果の様に空のリストを入力してもエラーとならずに初期値が返却される。
練習問題:fold
問題文は以下。
// 入力リストの各要素の積を2倍にした結果が得られるように???を埋めよう // このリストを与えると結果は48になるはずだ。 println(List(1, 2, 3, 4).fold(2)(_ * _))
解答 - クリックすると開くので、見たくない場合は開かないように注意。
println(List(1, 2, 3, 4).fold(2)(_ * _))
練習問題:Decoupledを使ったアービター
ここでChiselに戻って練習問題に取り組んでみよう。練習問題を訳したものを載せておく。
ここまでのすべて使った練習問題だ。この例題ではDecoupledアービターを作成する。 その仕様は以下のようなものだ:
- n 個のDecoupldedの入力を持ち、1つのDecoupledの出力を持つ。
- この時出力されるDecoupledはvalid信号を上げたうち一番小さいチャネルとなる。
いくつかヒントを与えておこう:
- アーキテクチャ面からのヒント
- 次の構造が助けになる:
- 入力の各要素を返却するような
map
:例えばio.in.map(_.valid)
は入力のBundleのvalid
を返却するPriorityMux(List[Bits, Bool])
:bits
データとvalid
信号のリストを引数にして、valid
がたった先頭要素を返却する- Vecの動的インデックス。インデックスは
UInt
で指定する:例えばio.in(0.U)
ちなみに上記のヒントにはPriorityMux(List[Bits, Bool])
が書いてあるが、PriorityMux
には以下の様に3つの使い方が用意されておりList[Bits, Bool]
を使うよりも、違う形態を使ったほうがもっとシンプルに書ける。
/** Builds a Mux tree under the assumption that multiple select signals * can be enabled. Priority is given to the first select signal. * * Returns the output of the Mux tree. */ object PriorityMux { def apply[T <: Data](in: Seq[(Bool, T)]): T = SeqUtils.priorityMux(in) def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = apply(sel zip in) def apply[T <: Data](sel: Bits, in: Seq[T]): T = apply((0 until in.size).map(sel(_)), in) }
回答用のスケルトンは以下になる。
class MyRoutingArbiter(numChannels: Int) extends Module { val io = IO(new Bundle { val in = Vec(numChannels, Flipped(Decoupled(UInt(8.W)))) val out = Decoupled(UInt(8.W)) } ) // Your code here ??? } // verify that the computation is correct class MyRoutingArbiterTester(c: MyRoutingArbiter) extends PeekPokeTester(c) { // Set input defaults for(i <- 0 until 4) { poke(c.io.in(i).valid, 0) poke(c.io.in(i).bits, i) poke(c.io.out.ready, 1) } expect(c.io.out.valid, 0) // Check single input valid behavior with backpressure for (i <- 0 until 4) { poke(c.io.in(i).valid, 1) expect(c.io.out.valid, 1) expect(c.io.out.bits, i) poke(c.io.out.ready, 0) expect(c.io.in(i).ready, 0) poke(c.io.out.ready, 1) poke(c.io.in(i).valid, 0) } // Basic check of multiple input ready behavior with backpressure poke(c.io.in(1).valid, 1) poke(c.io.in(2).valid, 1) expect(c.io.out.bits, 1) expect(c.io.in(1).ready, 1) expect(c.io.in(0).ready, 0) poke(c.io.out.ready, 0) expect(c.io.in(1).ready, 0) } val works = Driver(() => new MyRoutingArbiter(4)) { c => new MyRoutingArbiterTester(c) } assert(works) // Scala Code: if works == false, will throw an error println("SUCCESS!!") // Scala Code: if we get here, our tests passed!
結果は当たり前だがテストにPASSする。 ここまでを踏まえて実装した自分の解答 - クリックすると開くので、見たくない場合は開かないように注意。
class MyRoutingArbiter(numChannels: Int) extends Module {
val io = IO(new Bundle {
val in = Vec(Flipped(Decoupled(UInt(8.W))), numChannels)
val out = Decoupled(UInt(8.W))
} )
// YOUR CODE BELOW
io.out.valid := io.in.map(_.valid).reduce(_ || _)
val channel = PriorityMux(
io.in.map(_.valid).zipWithIndex.map { case (valid, index) => (valid, index.U) }
)
io.out.bits := io.in(channel).bits
for ((ready, index) <- io.in.map(_.ready).zipWithIndex) {
ready := io.out.ready && channel === index.U
}
}
[info] [0.001] Elaborating design...
[info] [0.775] Done elaborating.
Total FIRRTL Compile Time: 237.2 ms
Total FIRRTL Compile Time: 35.5 ms
End of dependency graph
Circuit state created
[info] [0.002] SEED 1548077180980
test cmd3HelperMyRoutingArbiter Success: 17 tests passed in 5 cycles taking 0.029090 seconds
[info] [0.015] RAN 0 CYCLES PASSED
SUCCESS!!
ちなみに前にも書いた気もするけどPriorityMux
を使って作られたMux回路をVerilogに変換すると以下のような3項演算子を使った回路になる。そのため実際の設計で使う際には組み方に気をつけないとタイミング面で面倒なことになるので気をつけておきたい。
module cmd3HelperMyRoutingArbiter( // @[:@3.2] input clock, // @[:@4.4] input reset, // @[:@5.4] output io_in_0_ready, // @[:@6.4] input io_in_0_valid, // @[:@6.4] input [7:0] io_in_0_bits, // @[:@6.4] output io_in_1_ready, // @[:@6.4] input io_in_1_valid, // @[:@6.4] input [7:0] io_in_1_bits, // @[:@6.4] output io_in_2_ready, // @[:@6.4] input io_in_2_valid, // @[:@6.4] input [7:0] io_in_2_bits, // @[:@6.4] input io_out_ready, // @[:@6.4] output io_out_valid, // @[:@6.4] output [7:0] io_out_bits // @[:@6.4] ); wire _T_43; // @[cmd3.sc 7:47:@8.4] wire _T_44; // @[cmd3.sc 7:47:@9.4] wire [1:0] _T_48; // @[Mux.scala 31:69:@11.4] wire [1:0] channel; // @[Mux.scala 31:69:@12.4] wire [7:0] _GEN_5; // @[cmd3.sc 11:15:@13.4] wire [7:0] _GEN_8; // @[cmd3.sc 11:15:@13.4] wire _T_52; // @[cmd3.sc 13:38:@14.4] wire _T_53; // @[cmd3.sc 13:27:@15.4] wire _T_55; // @[cmd3.sc 13:38:@17.4] wire _T_56; // @[cmd3.sc 13:27:@18.4] wire _T_58; // @[cmd3.sc 13:38:@20.4] wire _T_59; // @[cmd3.sc 13:27:@21.4] assign _T_43 = io_in_0_valid | io_in_1_valid; // @[cmd3.sc 7:47:@8.4] assign _T_44 = _T_43 | io_in_2_valid; // @[cmd3.sc 7:47:@9.4] // 最も小さいチャネルのvalidが優先される。 assign _T_48 = io_in_1_valid ? 2'h1 : 2'h2; // @[Mux.scala 31:69:@11.4] assign channel = io_in_0_valid ? 2'h0 : _T_48; // @[Mux.scala 31:69:@12.4] // channelに従って出力するデータを決定する assign _GEN_5 = 2'h1 == channel ? io_in_1_bits : io_in_0_bits; // @[cmd3.sc 11:15:@13.4] assign _GEN_8 = 2'h2 == channel ? io_in_2_bits : _GEN_5; // @[cmd3.sc 11:15:@13.4] assign _T_52 = channel == 2'h0; // @[cmd3.sc 13:38:@14.4] assign _T_53 = io_out_ready & _T_52; // @[cmd3.sc 13:27:@15.4] assign _T_55 = channel == 2'h1; // @[cmd3.sc 13:38:@17.4] assign _T_56 = io_out_ready & _T_55; // @[cmd3.sc 13:27:@18.4] assign _T_58 = channel == 2'h2; // @[cmd3.sc 13:38:@20.4] assign _T_59 = io_out_ready & _T_58; // @[cmd3.sc 13:27:@21.4] assign io_in_0_ready = _T_53; assign io_in_1_ready = _T_56; assign io_in_2_ready = _T_59; assign io_out_valid = _T_44; assign io_out_bits = _GEN_8; endmodule
ということでModule3.3はこれでお終い。 次からはModule3.4関数型言語でScalaの持つ関数型言語の特徴をChiselで使うこと考えていくようだ。