scala/scala3

typeclass

wefree 2024. 3. 22. 00:21

OLD but Good 방법: typeclass trait, companion object 안의 given 인스턴스, extension method 를 제공하는 별도 syntax 객체로 나누어 정의한다.

/** Type class: an instance `Shape[A]` knows how to compute the area of an `A`. */
trait Shape[A] {
  /** Returns the area of `a`. */
  def area(a: A): Double
}

object Shape {
  /** Computes the area of `a` with the `Shape[A]` instance resolved at the call site. */
  def area[A](a: A)(using x: Shape[A]): Double = x.area(a)

  /** Rectangle instance: width times height. */
  given Shape[Rect] with {
    override def area(a: Rect): Double = a.width * a.height
  }

  /** Circle instance, using 3.14 as the approximation of pi. */
  given Shape[Circle] with {
    override def area(a: Circle): Double = 3.14 * a.radius * a.radius
  }
}

object ShapeSyntax {
  /** Adds `.area` to any value whose type has a `Shape` instance available at the call site. */
  extension [A](a: A) def area(using Shape[A]): Double = summon[Shape[A]].area(a)
}

/** Rectangle described by its `width` and `height`; `final` because case classes should not be extended. */
final case class Rect(width: Double, height: Double)

/** Circle described by its `radius`; `final` because case classes should not be extended. */
final case class Circle(radius: Double)

object Test {
  /** Demo: computes the areas of a sample rectangle and circle via the `.area` syntax and prints them. */
  def main(args: Array[String]): Unit = {
    import Shape.given   // the Shape[Rect] and Shape[Circle] instances
    import ShapeSyntax.* // the generic `.area` extension method

    val rect: Rect = Rect(3, 4)
    val circle: Circle = Circle(10)

    val areas: Seq[Double] = Seq(rect.area, circle.area)
    areas.foreach(println) // 12.0, then 314.0
  }
}

 

 


Scala 3 에서는 trait 안에 extension method 를 직접 선언하는, 더 간결한 형태의 typeclass 정의를 지원한다.

그런데 『Scala with Cats』 책에서는 OLD 방법을 추천하고 있다.

 

공식 문서: https://docs.scala-lang.org/scala3/reference/contextual/type-classes.html

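(import 된 given 인스턴스 안에 정의된 extension method 는 따로 import 하지 않아도 바로 적용된다. 공식 문서의 extension method 해석 규칙에 따르면, 참조 지점에서 보이는(visible) given 인스턴스의 멤버인 extension method 는 호출 후보가 된다. 그래서 아래 isBig 처럼 `using Shape[A]` 파라미터만 있어도 `a.area()` 를 쓸 수 있다.)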

예제 1 (case class 마다 given 인스턴스를 따로 정의)

/** Rectangle described by its `width` and `height` (`final`: case classes should not be extended). */
final case class Rect(width: Double, height: Double)
/** Circle described by its `radius` (`final`: case classes should not be extended). */
final case class Circle(radius: Double)

/** Type class whose instances equip values of type `A` with an `area()` extension method. */
trait Shape[A] {
  /** Area of the receiver `a`. */
  extension (a: A) def area(): Double
}

object Shape {
  // Generic instances can also be declared with type parameters and their own using clauses,
  // e.g. `given listLast[T](using ...): Last[List[T]] = ...`.

  /** Anonymous instance: importable via `import Shape.given` or by type, `import Shape.{given Shape[Rect]}`. */
  given Shape[Rect] = new Shape[Rect] {
    extension (a: Rect) {
      def area(): Double = a.width * a.height
    }
  }

  /** Named instance: can also be imported on its own via `import Shape.circleShape`. */
  given circleShape: Shape[Circle] = new Shape[Circle] {
    extension (a: Circle) {
      // uses 3.14 as the approximation of pi
      def area(): Double = 3.14 * a.radius * a.radius
    }
  }

  /** Tells whether the area of `a` reaches `threshold`.
    *
    * `a.area()` resolves because the anonymous `Shape[A]` using-parameter is a given visible
    * here, so its extension methods are applicable to `a`.
    *
    * @param a         value to measure
    * @param threshold minimum area (inclusive) counted as "big"; defaults to 100
    * @return `true` when `a.area() >= threshold`, `false` otherwise (including a NaN area)
    */
  def isBig[A](a: A, threshold: Double = 100)(using Shape[A]): Boolean =
    a.area() >= threshold
}

@main def main(): Unit = {
  import Shape.given // import every given defined in Shape
  // import Shape.circleShape                              // a named given can be imported on its own
  // import Shape.{given Shape[Circle], given Shape[Rect]} // anonymous givens are imported by their type

  val rect: Rect = Rect(2, 3)
  val circle: Circle = Circle(10)

  // For each figure: print its area, then whether Shape.isBig considers it big.
  println(rect.area())         // 6.0
  println(Shape.isBig(rect))   // false
  println(circle.area())       // 314.0
  println(Shape.isBig(circle)) // true
}

 

 

예제 2 (enum 으로 정의한 ADT 하나에 given 인스턴스 하나를 두고 pattern match 로 분기)

/** Closed set of 2-D figures. */
enum Fig2D {
  /** Rectangle with sides `w` and `h`. */
  case Rect(w: Double, h: Double)

  /** Circle with radius `r`. */
  case Circle(r: Double)
}

import Fig2D.*

/** Type class whose instances add an `area()` extension method to values of type `A`. */
trait Shape[A] {
  extension (a: A) {
    /** Returns the area of the receiver. */
    def area(): Double
  }
}

object Shape {
  /** Single instance covering every `Fig2D` case; the match is exhaustive over the enum. */
  given Shape[Fig2D] with {
    extension (shape: Fig2D) override def area(): Double =
      shape match {
        case Fig2D.Rect(w, h) => w * h
        case Fig2D.Circle(r)  => r * r * 3.14 // 3.14 approximates pi
      }
  }
}


object Test {
  /** Demo: prints the area of each sample figure through the `Shape[Fig2D]` given. */
  def main(args: Array[String]): Unit = {
    // Needed: the instance lives in Shape's companion, which is not in Fig2D's implicit scope.
    import Shape.given

    val rect: Fig2D = Rect(3, 4)
    val circle: Fig2D = Circle(2)

    // Previously `circle` was built but never used; both figures are now reported.
    println(rect.area())   // 12.0
    println(circle.area()) // 12.56
  }
}