ハードウェアの気になるあれこれ

技術的に興味のあることを調べて書いてくブログ。主にハードウェアがネタ。

Chisel Bootcamp - Module3.3(2) - 高階関数 - map等を使ったアービターの設計

スポンサーリンク

前回のChiselの記事では久しぶりにChisel-bootcampの学習に戻りScala高階関数についてを学習した。

www.tech-diningyo.info

今日も引き続き高階関数の章でもう少し例題の確認と練習問題に取り組んでいく。

高階関数

mapの練習問題から再開する。

練習問題:map

問題文は以下になる。

// 入力のリストを倍にするように`???`を埋めよう
// これは次の値を返すはずだ: List(2, 4, 6, 8)
println(List(1, 2, 3, 4).map(???))

解答 - クリックすると開くので、見たくない場合は開かないように注意。

println(List(1, 2, 3, 4).map(_ * 2)

例題:zipWithIndex

zipWithIndexの定義はzipWithIndex: List[(A, Int)]となっており、引数を取らないがリストの各要素にインデックスを付与したタプルのリストを返すという処理を行う。 言葉では若干わかりづらいが、実行してみれば一目瞭然だ。・

println(List(1, 2, 3, 4).zipWithIndex)  // インデックスは0からスタートする
println(List("a", "b", "c", "d").zipWithIndex)
// タプルがネストするようなケースも可能
println(List(("a", "b"), ("c", "d"), ("e", "f"), ("g", "h")).zipWithIndex)

実行すると以下のようになる。

List((1,0), (2,1), (3,2), (4,3))
List((a,0), (b,1), (c,2), (d,3))
List(((a,b),0), ((c,d),1), ((e,f),2), ((g,h),3))

pythonのenumerateの処理に似てるといえば似てる。

例題:reduce

これはすでに使っているが、リストの各要素同士で処理を適用するという処理になる。定義はreduce(op: (A, A) ⇒ A): A

println(List(1, 2, 3, 4).reduce((a, b) => a + b))  // returns the sum of all the elements
println(List(1, 2, 3, 4).reduce(_ * _))  // returns the product of all the elements
println(List(1, 2, 3, 4).map(_ + 1).reduce(_ + _))  // you can chain reduce onto the result of a map

結果はListの中身であるInt型がreduce内部の処理で計算されてInt型として返ってくる。

10
24
14

なおreduceはリストの中身が空の場合は以下のようなエラーが発生する。

java.lang.UnsupportedOperationException: empty.reduceLeft
  scala.collection.LinearSeqOptimized$class.reduceLeft(LinearSeqOptimized.scala:137)
  scala.collection.immutable.List.reduceLeft(List.scala:84)
  scala.collection.TraversableOnce$class.reduce(TraversableOnce.scala:208)
  scala.collection.AbstractTraversable.reduce(Traversable.scala:104)
  ammonite.$sess.cmd7$Helper.<init>(cmd7.sc:1)
  ammonite.$sess.cmd7$.<init>(cmd7.sc:7)
  ammonite.$sess.cmd7$.<clinit>(cmd7.sc:-1)

練習問題:reduce

問題文はこちら。

// 入力されたリストの各要素を2倍した値同士で積を取るように???を埋めよう
// これは次の値を返すはずだ: (1*2)*(2*2)*(3*2)*(4*2) = 384
println(List(1, 2, 3, 4).map(???).reduce(???))

解答 - クリックすると開くので、見たくない場合は開かないように注意。

println(List(1, 2, 3, 4).map(_ * 2).reduce(_ * _))

例題:fold

foldreduceと似ている処理だ。定義はfold(z: A)(op: (A, A) ⇒ A): Aとなっておりreduceとの違いはopの前にzが追加されていることだ。このzはリストの最初に追加される要素となる。言い換えると初期値付きのrecudeと言ったところか。

println(List(1, 2, 3, 4).fold(0)(_ + _))  // equivalent to the sum using reduce
println(List(1, 2, 3, 4).fold(1)(_ + _))  // like above, but accumulation starts at 1
println(List().fold(1)(_ + _))  // unlike reduce, does not fail on an empty input

実行するとすべてただのIntが返却される。

10
11
1

1つ目と2つ目はzに入れる値が異なるため結果が1だけずれているのがわかると思う。またreduceとは異なり初期値を持たせることができるため、3つ目の結果の様に空のリストを入力してもエラーとならずに初期値が返却される。

練習問題:fold

問題文は以下。

// 入力リストの各要素の積を2倍にした結果が得られるように???を埋めよう
// このリストを与えると結果は48になるはずだ。
println(List(1, 2, 3, 4).fold(2)(_ * _))

解答 - クリックすると開くので、見たくない場合は開かないように注意。

println(List(1, 2, 3, 4).fold(2)(_ * _))

練習問題:Decoupledを使ったアービター

ここでChiselに戻って練習問題に取り組んでみよう。練習問題を訳したものを載せておく。

ここまでのすべて使った練習問題だ。この例題ではDecoupledアービターを作成する。 その仕様は以下のようなものだ:

  • n 個のDecoupldedの入力を持ち、1つのDecoupledの出力を持つ。
    • この時出力されるDecoupledはvalid信号を上げたうち一番小さいチャネルとなる。

いくつかヒントを与えておこう:

  • アーキテクチャ面からのヒント
    • io.out.validは入力のvalidのいずれかが上がればtureとなる。
    • 選択されたチャネル用の内部wireを持つことを検討しよう。
    • 入力のreadyはそのチャネルが選択された状態でoutputのreadyが上がるとtrueになる。(これは組み合わせ論理のvalid/readyになるが今は無視しよう)
  • 次の構造が助けになる:
    • 入力の各要素を返却するようなmap:例えばio.in.map(_.valid)は入力のBundleのvalidを返却する
    • PriorityMux(List[Bits, Bool])bitsデータとvalid信号のリストを引数にして、validがたった先頭要素を返却する
    • Vecの動的インデックス。インデックスはUIntで指定する:例えばio.in(0.U)

ちなみに上記のヒントにはPriorityMux(List[Bits, Bool])が書いてあるが、PriorityMuxには以下の様に3つの使い方が用意されておりList[Bits, Bool]を使うよりも、違う形態を使ったほうがもっとシンプルに書ける。

/** Builds a Mux tree under the assumption that multiple select signals
  * can be enabled. Priority is given to the first select signal.
  *
  * Returns the output of the Mux tree.
  */
object PriorityMux {
  def apply[T <: Data](in: Seq[(Bool, T)]): T = SeqUtils.priorityMux(in)
  def apply[T <: Data](sel: Seq[Bool], in: Seq[T]): T = apply(sel zip in)
  def apply[T <: Data](sel: Bits, in: Seq[T]): T = apply((0 until in.size).map(sel(_)), in)
}

回答用のスケルトンは以下になる。

class MyRoutingArbiter(numChannels: Int) extends Module {
  val io = IO(new Bundle {
    val in = Vec(numChannels, Flipped(Decoupled(UInt(8.W))))
    val out = Decoupled(UInt(8.W))
  } )

  // Your code here
  ???
}

// verify that the computation is correct
class MyRoutingArbiterTester(c: MyRoutingArbiter) extends PeekPokeTester(c) {
  // Set input defaults
  for(i <- 0 until 4) {
    poke(c.io.in(i).valid, 0)
    poke(c.io.in(i).bits, i)
    poke(c.io.out.ready, 1)
  }

  expect(c.io.out.valid, 0)

  // Check single input valid behavior with backpressure
  for (i <- 0 until 4) {
    poke(c.io.in(i).valid, 1)
    expect(c.io.out.valid, 1)
    expect(c.io.out.bits, i)

    poke(c.io.out.ready, 0)
    expect(c.io.in(i).ready, 0)

    poke(c.io.out.ready, 1)
    poke(c.io.in(i).valid, 0)
  }

  // Basic check of multiple input ready behavior with backpressure
  poke(c.io.in(1).valid, 1)
  poke(c.io.in(2).valid, 1)
  expect(c.io.out.bits, 1)
  expect(c.io.in(1).ready, 1)
  expect(c.io.in(0).ready, 0)

  poke(c.io.out.ready, 0)
  expect(c.io.in(1).ready, 0)
}

val works = Driver(() => new MyRoutingArbiter(4)) {
  c => new MyRoutingArbiterTester(c)
}
assert(works)        // Scala Code: if works == false, will throw an error
println("SUCCESS!!") // Scala Code: if we get here, our tests passed!

ここまでを踏まえて実装した自分の解答 - クリックすると開くので、見たくない場合は開かないように注意。

class MyRoutingArbiter(numChannels: Int) extends Module {
  val io = IO(new Bundle {
    val in = Vec(Flipped(Decoupled(UInt(8.W))), numChannels)
    val out = Decoupled(UInt(8.W))
  } )

  // YOUR CODE BELOW
  io.out.valid := io.in.map(_.valid).reduce(_ || _)
  val channel = PriorityMux(
    io.in.map(_.valid).zipWithIndex.map { case (valid, index) => (valid, index.U) }
  )
  io.out.bits := io.in(channel).bits
  for ((ready, index) <- io.in.map(_.ready).zipWithIndex) {
    ready := io.out.ready && channel === index.U
  }
}

結果は当たり前だがテストにPASSする。

[info] [0.001] Elaborating design...
[info] [0.775] Done elaborating.
Total FIRRTL Compile Time: 237.2 ms
Total FIRRTL Compile Time: 35.5 ms
End of dependency graph
Circuit state created
[info] [0.002] SEED 1548077180980
test cmd3HelperMyRoutingArbiter Success: 17 tests passed in 5 cycles taking 0.029090 seconds
[info] [0.015] RAN 0 CYCLES PASSED
SUCCESS!!

ちなみに前にも書いた気もするけどPriorityMuxを使って作られたMux回路をVerilogに変換すると以下のような3項演算子を使った回路になる。そのため実際の設計で使う際には組み方に気をつけないとタイミング面で面倒なことになるので気をつけておきたい。

module cmd3HelperMyRoutingArbiter( // @[:@3.2]
  input        clock, // @[:@4.4]
  input        reset, // @[:@5.4]
  output       io_in_0_ready, // @[:@6.4]
  input        io_in_0_valid, // @[:@6.4]
  input  [7:0] io_in_0_bits, // @[:@6.4]
  output       io_in_1_ready, // @[:@6.4]
  input        io_in_1_valid, // @[:@6.4]
  input  [7:0] io_in_1_bits, // @[:@6.4]
  output       io_in_2_ready, // @[:@6.4]
  input        io_in_2_valid, // @[:@6.4]
  input  [7:0] io_in_2_bits, // @[:@6.4]
  input        io_out_ready, // @[:@6.4]
  output       io_out_valid, // @[:@6.4]
  output [7:0] io_out_bits // @[:@6.4]
);
  wire  _T_43; // @[cmd3.sc 7:47:@8.4]
  wire  _T_44; // @[cmd3.sc 7:47:@9.4]
  wire [1:0] _T_48; // @[Mux.scala 31:69:@11.4]
  wire [1:0] channel; // @[Mux.scala 31:69:@12.4]
  wire [7:0] _GEN_5; // @[cmd3.sc 11:15:@13.4]
  wire [7:0] _GEN_8; // @[cmd3.sc 11:15:@13.4]
  wire  _T_52; // @[cmd3.sc 13:38:@14.4]
  wire  _T_53; // @[cmd3.sc 13:27:@15.4]
  wire  _T_55; // @[cmd3.sc 13:38:@17.4]
  wire  _T_56; // @[cmd3.sc 13:27:@18.4]
  wire  _T_58; // @[cmd3.sc 13:38:@20.4]
  wire  _T_59; // @[cmd3.sc 13:27:@21.4]
  assign _T_43 = io_in_0_valid | io_in_1_valid; // @[cmd3.sc 7:47:@8.4]
  assign _T_44 = _T_43 | io_in_2_valid; // @[cmd3.sc 7:47:@9.4]
  // 最も小さいチャネルのvalidが優先される。
  assign _T_48 = io_in_1_valid ? 2'h1 : 2'h2; // @[Mux.scala 31:69:@11.4]
  assign channel = io_in_0_valid ? 2'h0 : _T_48; // @[Mux.scala 31:69:@12.4]
  // channelに従って出力するデータを決定する
  assign _GEN_5 = 2'h1 == channel ? io_in_1_bits : io_in_0_bits; // @[cmd3.sc 11:15:@13.4]
  assign _GEN_8 = 2'h2 == channel ? io_in_2_bits : _GEN_5; // @[cmd3.sc 11:15:@13.4]
  assign _T_52 = channel == 2'h0; // @[cmd3.sc 13:38:@14.4]
  assign _T_53 = io_out_ready & _T_52; // @[cmd3.sc 13:27:@15.4]
  assign _T_55 = channel == 2'h1; // @[cmd3.sc 13:38:@17.4]
  assign _T_56 = io_out_ready & _T_55; // @[cmd3.sc 13:27:@18.4]
  assign _T_58 = channel == 2'h2; // @[cmd3.sc 13:38:@20.4]
  assign _T_59 = io_out_ready & _T_58; // @[cmd3.sc 13:27:@21.4]
  assign io_in_0_ready = _T_53;
  assign io_in_1_ready = _T_56;
  assign io_in_2_ready = _T_59;
  assign io_out_valid = _T_44;
  assign io_out_bits = _GEN_8;
endmodule

ということでModule3.3はこれでお終い。 次からはModule3.4関数型言語Scalaの持つ関数型言語の特徴をChiselで使うこと考えていくようだ。