ハードウェアの気になるあれこれ

技術的に興味のあることを調べて書いてくブログ。主にハードウェアがネタ。

Chisel Bootcamp - Module3.4(2) - 関数型言語の特徴を活かしたChiselのHW設計

スポンサーリンク

前回のChiselの記事ではChisel-bootcampのModule3.4に入りChiselのベースになっているScala関数型言語としての特徴についてを確認した。

www.tech-diningyo.info

今日も引き続きModule3.4を勉強するが、今日はいよいよ関数型言語の特徴をどうChiselに活かすかという部分についてを見ていく。

関数型言語

Chiselでの関数プログラミング

例題: FIRフィルタ

ここでは以前に作成したFIRフィルタを通して、関数型言語の特徴をChiselに適用する方法について確認していく。 参考までに以前に作成したFIRフィルタは以下のようなものだった。

class MyManyElementFir(consts: Seq[Int], bitWidth: Int) extends Module {
  val io = IO(new Bundle {
    val in = Input(UInt(bitWidth.W))
    val out = Output(UInt(bitWidth.W))
  })

  val regs = mutable.ArrayBuffer[UInt]()
  for(i <- 0 until consts.length) {
      if(i == 0) regs += io.in
      else       regs += RegNext(regs(i - 1), 0.U)
  }
  
  val muls = mutable.ArrayBuffer[UInt]()
  for(i <- 0 until consts.length) {
      muls += regs(i) * consts(i).U
  }

  val scan = mutable.ArrayBuffer[UInt]()
  for(i <- 0 until consts.length) {
      if(i == 0) scan += muls(i)
      else scan += muls(i) + scan(i - 1)
  }

  io.out := scan.last
}

ここでは上記の例の様に重みを渡したりする代わりに、”FIRフィルタはどのような計算を行うものか”という関数を渡すことにしてみよう。

// 数学関連の関数のインポート
import scala.math.{abs, round, cos, Pi, pow}

// 単純な三角窓
val TriangularWindow: (Int, Int) => Seq[Int] = (length, bitwidth) => {
  val raw_coeffs = (0 until length).map( (x:Int) => 1-abs((x.toDouble-(length-1)/2.0)/((length-1)/2.0)) )
  val scaled_coeffs = raw_coeffs.map( (x: Double) => round(x * pow(2, bitwidth)).toInt)
  scaled_coeffs
}

// ハミング窓
val HammingWindow: (Int, Int) => Seq[Int] = (length, bitwidth) => {
  val raw_coeffs = (0 until length).map( (x: Int) => 0.54 - 0.46*cos(2*Pi*x/(length-1)))
  val scaled_coeffs = raw_coeffs.map( (x: Double) => round(x * pow(2, bitwidth)).toInt)
  scaled_coeffs
}

// 処理の確認! 第一引数は窓の長さ、第二引数はビット幅となる
TriangularWindow(10, 16)
HammingWindow(10, 16)

これはなんの変哲もないScalaの関数で、実行すると以下のようになる。

import scala.math.{abs, round, cos, Pi, pow}

// simple triangular window

TriangularWindow: (Int, Int) => Seq[Int] = <function2>
HammingWindow: (Int, Int) => Seq[Int] = <function2>
res4_3: Seq[Int] = Vector(
  0,
  14564,
  29127,
  43691,
  58254,
  58254,
  43691,
  29127,
  14564,
  0
)
res4_4: Seq[Int] = Vector(
  5243,
  12296,
  30155,
  50463,
  63718,
  63718,
  50463,
  30155,
  12296,
  5243
)

上記の関数を引数として渡すことの出来るChiselのFIRフィルタを作成しよう。このようにすると、既に前回の例でも見かけたように渡す関数を切り替えることで処理の中身を切り替えることが可能となる。

// このFIRフィルタは窓関数の長さとIOのビット幅、そして窓関数がパラメタライズされている。
class MyFir(length: Int, bitwidth: Int, window: (Int, Int) => Seq[Int]) extends Module {
  val io = IO(new Bundle {
    val in = Input(UInt(bitwidth.W))
    val out = Output(UInt((bitwidth*2+length-1).W)) // expect bit growth, conservative but lazy
  })

  // 与えられた窓関数からcoeffを計算してUIntsにマップする。
  val coeffs = window(length, bitwidth).map(_.U)
  
  // 遅延データ用の配列を確保
  //  → ここではChiselの回路としての動的なインデックスへのアクセスが
  //     必要ないのでVecを使用していない。
  val delays = Seq.fill(length)(Wire(UInt(bitwidth.W))).scan(io.in)( (prev: UInt, next: UInt) => {
    next := RegNext(prev)
    next
  })
  
  // 乗算した結果を"mults"に接続する
  val mults = delays.zip(coeffs).map{ case(delay: UInt, coeff: UInt) => delay * coeff }
  
  // ビット幅の拡張付きで乗算結果を足し合わせる
  val result = mults.reduce(_+&_)

  // 結果を出力に接続
  io.out := result
}

例題: FIRフィルタのテスト回路

前回のFIRフィルタの回ではテスト用の回路を作成する際のゴールデンモデルとして自作のFIRフィルタ関数を使用した。 今回はScala線形代数や信号処理のためのライブラリであるBreezeを使ってテストを構築している。

github.com

正直なところScalaはまだライブラリを使った経験が浅いのでこのライブラリもこのModuleをやった際に初めて知った。。 ざっとWikiを眺めた見た感じ、pythonのnumpyやscipyと同じようなライブラリに見える。このライブラリの中にはフィルタ関係の処理も含まれているようで今回はこれを使って構築をするようだ。 個人的にはこういった形で既に開発されて世の中でも使われている信頼のおけるライブラリをそのまま使ってHW用のテストを書けるというのもChiselの強みだと思う。 画像処理や暗号化などなど、期待値の生成に使えたらありがたいものは山ほどあるだろう。

// 数学系のライブラリインポート
import scala.math.{pow, sin, Pi}
import breeze.signal.{filter, OptOverhang}
import breeze.signal.support.{CanFilter, FIRKernel1D}
import breeze.linalg.DenseVector

// テストのパラメータ
val length = 7
val bitwidth = 12 // 15未満の数字が必須条件。それ以上だとInt型では処理できるBigIntが必要になる。
val window = TriangularWindow

// FIRフィルタのテスト
Driver(() => new MyFir(length, bitwidth, window)) {
  c => new PeekPokeTester(c) {
    
    // テストデータ
    val n = 100 // input length
    val sine_freq = 10
    val samp_freq = 100
    
    // サンプルデータ。 0-2^(bitWidth)までスケール可能
    val max_value = pow(2, bitwidth)-1
    val sine = (0 until n).map(i => (max_value/2 + max_value/2*sin(2*Pi*sine_freq/samp_freq*i)).toInt)
    //println(s"input = ${sine.toArray.deep.mkString(", ")}")
    
    // 係数
    val coeffs = window(length, bitwidth)
    //println(s"coeffs = ${coeffs.toArray.deep.mkString(", ")}")

    // breezeのフィルタをゴールデンモデルとする。そのために逆順の係数が必要
    val expected = filter(DenseVector(sine.toArray), 
                          FIRKernel1D(DenseVector(coeffs.reverse.toArray), 1.0, ""), 
                          OptOverhang.None)
    //println(s"exp_out = ${expected.toArray.deep.mkString(", ")}")

    // FIRフィルタにデータを入力し、結果を確認する
    reset(5)
    for (i <- 0 until n) {
      poke(c.io.in, sine(i))
      if (i >= length-1) { // データが0埋めされていないので、全てのレジスタの初期化を待つ
        expect(c.io.out, expected(i-length+1))
        //println(s"cycle $i, got ${peek(c.io.out)}, expect ${expected(i-length+1)}")
      }
      step(1)
    }
  }
}

とりあえず今日はここまで。。 ほんとは演習もやろうと思ったんだけど、思いの外ボリュームがありそうで。。。。因みに演習は”ニューラルネットワークニューロンを作る”というものだ。ここまでやってきた高階関数関数プログラミングがうまい具合にハマる例の一つだと思う。