scala/scala3
typeclass
wefree
2024. 3. 22. 00:21
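OLD but Good 방법: trait 로 typeclass 를 정의하고, 인스턴스는 companion object 의 given 으로, 호출 문법(extension method)은 별도의 syntax object 로 제공한다 (Scala 2 / Cats 스타일).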
/** Type class describing values of type `A` that have an area.
  *
  * @tparam A the concrete shape type an instance knows how to measure
  */
trait Shape[A] {

  /** Returns the area of `a`. */
  def area(a: A): Double
}
/** Companion of [[Shape]]: a summoning helper plus the built-in instances.
  *
  * Because these givens live in Shape's companion object, they belong to the
  * implicit scope of `Shape[Rect]` / `Shape[Circle]`, so `using` clauses can
  * find them even without an explicit import.
  */
object Shape {

  /** Computes the area of `a` with whichever `Shape[A]` instance is in scope. */
  def area[A](a: A)(using x: Shape[A]): Double = x.area(a)

  /** Rectangle area: width times height. */
  given Shape[Rect] = (r: Rect) => r.width * r.height

  /** Circle area, using the approximation pi ~= 3.14 (not math.Pi). */
  given Shape[Circle] = (c: Circle) => 3.14 * c.radius * c.radius
}
/** Extension-method syntax for [[Shape]]: `import ShapeSyntax.*` lets callers write `value.area`. */
object ShapeSyntax {
  extension [A](a: A) {

    /** Area of this value, computed by the `Shape[A]` instance in scope. */
    def area(using shape: Shape[A]): Double = shape.area(a)
  }
}
/** An axis-aligned rectangle described by its side lengths. */
case class Rect(
    width: Double,
    height: Double
)

/** A circle described by its radius. */
case class Circle(
    radius: Double
)
/** Demo for the "old but good" encoding: instances in the companion, syntax in ShapeSyntax. */
object Test {
  def main(args: Array[String]): Unit = {
    // Optional here: givens in Shape's companion are already in the implicit
    // scope of Shape[Rect] / Shape[Circle]; the import only makes that explicit.
    import Shape.given
    // Required: brings the `area` extension method into scope.
    import ShapeSyntax.*

    val r = Rect(3, 4)
    val c = Circle(10)

    println(r.area) // 12.0
    println(c.area) // 314.0
  }
}
Scala 3 에서는 typeclass 를 더 간결한 형태로 정의할 수 있다 (trait 안에 extension method 를 직접 선언하는 방식).
그런데 Scala with Cats 에서는 위의 OLD 방법을 추천하고 있다.
공식 문서: https://docs.scala-lang.org/scala3/reference/contextual/type-classes.html
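(import 된 given 인스턴스 안에 정의된 extension method 는 따로 import 하지 않아도 확장 메서드로 적용된다. Scala 3 규칙상 현재 scope 에서 보이는 given 인스턴스의 멤버인 extension method 도 확장 메서드 후보가 되기 때문이다.)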
예제 1: trait 안에 extension method 를 선언하는 Scala 3 스타일
/** An axis-aligned rectangle described by its side lengths. */
case class Rect(
    width: Double,
    height: Double
)

/** A circle described by its radius. */
case class Circle(
    radius: Double
)
/** Scala 3 style type class: the operation is declared as an extension method,
  * so any `Shape[A]` instance visible at a call site makes `a.area()` available
  * on values of type `A`.
  */
trait Shape[A] {
  extension (a: A) def area(): Double
}
/** Companion of [[Shape]] holding the instances and helpers built on them. */
object Shape {
  // Instances may also be generic and depend on other givens, e.g.
  //   given listLast[T](using ...): Last[List[T]] = ...

  /** Rectangle instance: width times height.
    * Anonymous, so its synthesized name is `given_Shape_Rect`.
    */
  given Shape[Rect] = new Shape[Rect] {
    extension (a: Rect) {
      def area(): Double = a.width * a.height
    }
  }

  /** Circle instance, named so it can be imported individually.
    * Uses the approximation pi ~= 3.14.
    */
  given circleShape: Shape[Circle] = new Shape[Circle] {
    extension (a: Circle) {
      def area(): Double = 3.14 * a.radius * a.radius
    }
  }

  /** Tells whether a shape's area reaches a threshold.
    *
    * The `area()` extension used here comes from the anonymous `Shape[A]`
    * given in the using clause.
    *
    * @param a         the value to measure
    * @param threshold minimum area considered "big"; defaults to 100
    * @return true when `a.area() >= threshold`; false when the area is NaN
    */
  def isBig[A](a: A, threshold: Double = 100)(using Shape[A]): Boolean =
    a.area() >= threshold
}
/** Demo entry point for example 1: prints each shape's area and whether it is "big". */
@main def main(): Unit = {
  // Brings every given from Shape into scope; the `area` extension methods
  // defined inside those instances then become callable on Rect/Circle values.
  import Shape.given
  // Alternatives:
  // import Shape.circleShape                              // a named given, imported by its name
  // import Shape.{given Shape[Circle], given Shape[Rect]} // anonymous givens, imported by type

  val r = Rect(2, 3)
  println(r.area())       // 6.0
  println(Shape.isBig(r)) // false

  val c = Circle(10)
  println(c.area())       // 314.0
  println(Shape.isBig(c)) // true
}
예제 2: enum(ADT) 전체에 대한 given 인스턴스 하나로 모든 case 를 처리
/** Closed set of 2D figures, all handled by a single `Shape[Fig2D]` instance. */
enum Fig2D {

  /** Rectangle with width `w` and height `h`. */
  case Rect(w: Double, h: Double)

  /** Circle with radius `r`. */
  case Circle(r: Double)
}
import Fig2D.*
/** Type class whose single operation, `area`, is exposed as an extension method
  * on values of type `A` wherever a `Shape[A]` instance is visible.
  */
trait Shape[A] {
  extension (a: A) {
    def area(): Double
  }
}
/** Companion of [[Shape]] with one instance covering every [[Fig2D]] case. */
object Shape {
  given Shape[Fig2D] = new Shape[Fig2D] {

    /** Area of any figure; circles use the approximation pi ~= 3.14. */
    extension (shape: Fig2D) override def area(): Double =
      shape match {
        // Keep r * r * 3.14 in this order: reordering float products can change the result.
        case Fig2D.Circle(radius)       => radius * radius * 3.14
        case Fig2D.Rect(width, height) => width * height
      }
  }
}
/** Demo for example 2: both figures share the single `Shape[Fig2D]` instance. */
object Test {
  def main(args: Array[String]): Unit = {
    // The `area` extension is defined inside the given, so importing the given
    // is enough to make `fig.area()` callable.
    import Shape.given

    val rect: Fig2D = Fig2D.Rect(3, 4)
    val circle: Fig2D = Fig2D.Circle(2)

    println(rect.area())   // 12.0
    println(circle.area()) // 12.56
  }
}