在 C++ 中实现柯里函数

对于函数式的编程语言,柯里化是语法(或者语法糖)的一部分。但是 C++ 要实现柯里函数,就要 lambda 套 lambda,就很不方便,于是我就想用模板元编程来实现一个把函数柯里化的函数。

注:以下代码用到了 C++20 的特性。

最初之作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
template<typename F>
auto curry(F&& f) {
if constexpr (std::is_invocable_v<F>) {
return f();
} else {
return [=]<typename T>(T&& x) {
return curry(
[=]<typename... Ts>(Ts&&... xs)->std::invoke_result_t<F, T, Ts...> {
return std::invoke(f, x, xs...);
}
);
};
}
}

int main() {
auto f = [](int a, int b, int c){ return a + b + c; };
std::cout << curry(f)(1)(2)(3) << std::endl; // output: 6
}

这是一个能用的柯里化函数(大概)。

试图否定无用 copy

之前的代码都用了按值捕获,从而产生了无用的 copy。

以下是用于测试的代码(先不考虑函数的 copy,只考虑参数,因为函数的 copy 也会导致参数的 copy):

1
2
3
4
5
6
7
8
9
10
11
12
13
struct Useless {
Useless() { std::cout << "Constructed" << std::endl; }
Useless(Useless&&) noexcept { std::cout << "Moved" << std::endl; }
Useless(const Useless&) { std::cout << "Copied" << std::endl; }
// ~Useless() { cout << "Deleted" << endl; }
};

int main() {
auto g = [](Useless a, Useless b){};
Useless u;
curry(g)(Useless())(u); // 6 copies
g(Useless(), u); // 1 copies
}

可以发现,curry 产生了 6 次 copy,但是如果按照正常方式调用,只会产生一次 copy。

那我们的目标是什么,也是只有一次 copy 吗?由于柯里化每指定一个参数都会产生一个新的函数(由于是匿名函数,所以更像是一个实现了 operator() 的结构体),之前的所有参数信息就在那个函数中。再一次调用产生结构体时,如果要避免对于之前参数的 move/copy,只能引用,但是会有 dangling reference 的问题(lambda 的 capture 不会延长右值的生命周期)。所以只能用若干次 move 来解决,也就是说,目标是产生一次 copy 以及若干尽可能少的 move,由于之前有 6 次 copy,只剩一个 copy 的话,那么需要用 5 个 move 来替代 copy。

在之前的基础上,试图用 std::forward 来解决一切问题,产生了下面的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
#define FWD(x) std::forward<decltype(x)>(x)
template<typename F>
auto curry(F&& f) {
if constexpr (std::is_invocable_v<F>) {
return f();
} else {
return [f = FWD(f)]<typename T>(T&& x) mutable {
return curry(
[f = FWD(f), x = FWD(x)]<typename... Ts>(Ts&&... xs) mutable -> std::invoke_result_t<F, T, Ts...> {
return std::invoke(f, FWD(x), FWD(xs)...);
});
};
}
}

对于之前的测试,会产出 1 次 copy 以及 5 次 move,看上去是达成目标了。

其实尝试过程中还做到过完全没有 move(一时不敢相信),但是在 curry 后的函数应用了部分参数并保存到变量后,计算会出错,如果想尝试的话,可以试试下面这个例子。一个典型的错误实现是把全部按值捕获改成全部按引用捕获。

1
2
3
4
5
6
7
8
int main() {
auto add3 = [](int a, int b, int c){ return a + b + c; };
auto a0 = curry(add3);
auto a1 = a0(1);
auto a2 = a1(2);
auto a3 = a2(3);
assert(a3 == 6); // a3 may equals to 9 if curry has wrong implementation
}

想要支持传入引用

对于下面使用了左值引用的例子,它无法通过编译。

1
2
3
4
int x = 0;
curry([](int& x){ x += 1; })(x);
// error: cannot bind non-const lvalue reference of type ‘int&’ to an rvalue of type ‘int’
std::cout << x << std::endl;

但是其实写 curry 是为了接近 functional programming,但是 functional programming 的函数参数不可能会是类似 C++ 中的左值引用,因为引用意味着副作用,所以把这个实现作为答案也合乎情理。然而我选择了继续。

从编译信息可以看出,最后传给被 curry 的匿名函数的并不是 x 的左值引用而是一个右值(如果把 int& 改成 int&& 就能通过编译也印证了这一点)。因此,怀疑 capture list 中的 std::forward 和预计的行为有出入。下面来构造一个更加简单的例子来说明问题:

1
2
3
4
5
6
7
8
9
10
11
12
template<typename T>
void f(T&& x) {
[x = std::forward<T>(x)]() mutable {
x += 1;
}();
}

int main() {
int x = 0;
f(x);
std::cout << x << std::endl; // output: 0
}

这次倒是没有编译错误了,但是可以肯定的是,传入 f 的一定是左值引用,f 内部 lambda 所 capture 的 x 一定不是左值引用。借助 cppinsight.io,可以看到模板展开后大概长这样(手动格式化了一下):

1
2
3
4
5
6
7
8
9
10
11
12
template<>
void f<int&>(int& x) {
class Lambda {
public:
Lambda(int& _x) : x{_x} {}
void operator()() { x += 1; }
private:
int x;
} lambda{std::forward<int&>(x)};

lambda();
}

问题在于 lambda 所形成的匿名类(这很 Java)中用于存放捕获的 x 的变量类型并不是引用。这也很合理,因为在 capture list 中,x 前面并没有 = 号,如果把 [x = ...] 改为 [&x = ...] 就能达到预期的效果。但是事情并没有那么容易,因为用了 universal reference,我们肯定希望对于左值引用和右值引用都能达到预期效果。但是修改后对于传入右值引用,会无法编译,因为没法把右值引用赋值给 lambda 匿名类中的左值引用成员变量。

此时你可能想问,那为什么成员变量里不放右值引用呢?理由依旧是 dangling reference,现在 f 是在函数体内直接调用 lambda,试想一下,f 改成返回 lambda,然后在别的地方调用。此时传入 f 的右值引用很可能已经结束了生命周期,然后 undefined behavior 就发生了。

考虑一下我想要的到底是什么,如果是 T&,我想用 [&x] 来捕获,如果是 T&&,那么我想用 [x = std::move(x)] 来捕获,代价是一次 move,那么写出来就是

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
template<typename T>
auto f(T&& x) {
if constexpr (std::is_lvalue_reference_v<T&&>) {
return [&x]() mutable {
x += 1;
};
} else {
return [x = std::move(x)]() mutable {
x += 1;
};
}
}

int main() {
int x = 0;
f(x)();
f(1)();
std::cout << x << std::endl; // output: 1
}

考虑简化以上代码。在查阅资料的时候发现 std::ref 例子中的 bind 也出现了类似的非预期行为,而解决的方法就是使用 std::ref。所以可以写一个新的 forward,根据情况返回 move 后的值或者一个 std::ref 包装的引用。

1
2
3
4
5
6
7
8
template<typename T>
constexpr auto capture_forward(auto&& x) {
if constexpr (std::is_lvalue_reference_v<T&&>) {
return std::ref(std::forward<T>(x));
} else {
return std::forward<T>(x);
}
}

然后把它应用到 curry 中(用到宏并不是因为我喜欢宏,只是因为它看上去更直观一些):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#define CAP_FWD(x) capture_forward<decltype(x)>(x)
#define STD_FWD(x) std::forward<decltype(x)>(x)
template<typename F>
auto curry(F&& f) {
if constexpr (std::is_invocable_v<F>) {
return f();
} else {
return [f = CAP_FWD(f)]<typename T>(T&& x) mutable {
return curry(
[f = CAP_FWD(f), x = CAP_FWD(x)]<typename... Ts>(Ts&&... xs) mutable -> std::invoke_result_t<F, T, Ts...> {
return std::invoke(f, STD_FWD(x), STD_FWD(xs)...);
});
};
}
}

不再拒绝重载函数

然而还是有问题,curry 理论上应该支持任意函数的柯里化,如果函数是被重载的,那么会根据传入参数的类型来决定。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// overloaded is copy from https://en.cppreference.com/w/cpp/utility/variant/visit
template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
// explicit deduction guide (not needed as of C++20) (zerol: still needed in my case)
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

void h(int&&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
void h(const int&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
void hh(auto&&) {}

int main() {
curry(overloaded {
[](int&&){ std::cout << "int&&" << std::endl; },
[](const int&){ std::cout << "const int&" << std::endl; },
})(1); // output nothing
curry([](auto&& x){ std::cout << typeid(x).name() << std::endl; })(1); // output: i (i means int)
curry(h)(1); // compile error: no matching function for call to ‘curry(<unresolved overloaded function type>)’
curry(hh)(1); // same as above
curry([](auto&& x){ /* type of x is int&& */ })(1); // compiles
int x;
curry([](auto&& x){ /* type of x is something like std::ref_wrapper<int> */ })(x); // compiles
}

根据结果,看上去对于 overloaded 的函数不大行,但是对于带 template 的函数还行。但是还有一个有趣的现象,对于 hh 这个函数不行,但是对于 lambda 版本的 hh 是可以的(之后会得到解释)。

根据错误信息,从一篇博客中抄来了 RESOLVE_OVERLOAD 这样一个宏。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#define RESOLVE_OVERLOAD(...) \
[](auto&&...args){ return __VA_ARGS__(std::forward<decltype(args)>(args)...); }

void h(int&&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
void h(const int&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }

int main() {
std::invoke(h, 1); // won't compile, no matching function for call to 'invoke', note: candidate template ignored: couldn't infer template argument '_Callable'
std::invoke(static_cast<void(*)(int&&)>(h), 1); // compiles
std::invoke(RESOLVE_OVERLOAD(h), 1); // compiles!
std::invoke(curry(h), 1); // won't compile, candidate template ignored: couldn't infer template argument 'F'

std::invoke(curry(RESOLVE_OVERLOAD(h)), 1); // no magic happens... error: no matching function for call to 'h'
// note: candidate function not viable: requires 1 argument, but 0 were provided
// void h(const int&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
// ^
// note: candidate function not viable: requires 1 argument, but 0 were provided
// void h(int&&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
// ^
}

std::invokecurry 都会把函数作为模板传入,因此 std::invoke 不能处理的重载问题,curry 更没法处理。除了手动指定类型,还有 RESOLVE_OVERLOAD 的方法,就是把重载函数替换成一个接受任意参数然后完美转发给原先函数的 lambda。为什么这会有用呢?因为决定 h 的类型的时机变了,从模板参数变成了一个 lambda 内部被调用的时候。在 lambda 被实例化的时候,args 的类型都是确定的,那么此时也就可以推断出 h 到底重载了哪个函数。当把 RESOLVED_OVERLOADED 应用于传入 curry 的重载函数后,奇迹并没有改善,根据错误提示,最后根本就没有参数传入 RESOLVE_OVERLOAD 包的那层 lambda。原因在于 curry 中判断了一下 std::is_invocable_v,判断的时候 lambda 被实例化了,然后发现找不到对应的 h,此时并非是实例化失败而是找不到签名合适的 h。所以不行的根本原因在于,RESOLVE_OVERLOAD 基于这样一个假设,lambda 被调用(或者实例化)时它的类型是正确的,借此决定合适的重载函数,而 curry 需要做错误的试探。所以要解决这个问题,需要的是在编译时检查一个重载函数能否被一些特定的参数调用,不能的话就让 RESOLVE_OVERLOAD lambda 实例化失败,于是就出现了之后的改进。

1
2
3
4
5
6
7
8
9
#define RESOLVE_OVERLOAD(...) \
[]<typename... Ts, typename = decltype(__VA_ARGS__(std::declval<Ts>()...))> (Ts&&... xs) { \
return __VA_ARGS__(std::forward<Ts>(xs)...); \
}

#define RESOLVE_OVERLOAD_VAR(...) \
[__VA_ARGS__]<typename... Ts, typename = decltype(__VA_ARGS__(std::declval<Ts>()...))> (Ts&&... xs) { \
return __VA_ARGS__(std::forward<Ts>(xs)...); \
}

改进之后就解决了问题,但是出现了两个版本 RESOLVE_OVERLOADRESOLVE_OVERLOAD_VAR

  • 对于变量(比如 lambda)形式存在的重载函数,使用 RESOLVE_OVERLOAD_VAR 来捕获它。
  • 对于函数,不能捕获,使用 RESOLVE_OVERLOAD
  • 对于右值,使用 RESOLVE_OVERLOAD。但建议赋值到变量再搞,不然可能出现 lambda expression in an unevaluated operand 之类的错误(clang 会编译错误,但 gcc 可以,这是 C++20 特性,怀疑 clang 没有实现这个特性)。

以下为测试代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void h(int&&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
void h(const int&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
void hh(auto&&) { std::cout << __PRETTY_FUNCTION__ << std::endl; }

int main() {
curry(RESOLVE_OVERLOAD(overloaded {
[](int&&){ std::cout << "int&&" << std::endl; },
[](const int&){ std::cout << "const int&" << std::endl; },
}))(1); // (can not compile in clang 10.0.1) output (in gcc): int&&
auto hhh = overloaded {
[](int&&){ std::cout << "int&&" << std::endl; },
[](const int&){ std::cout << "const int&" << std::endl; },
};
curry(RESOLVE_OVERLOAD_VAR(hhh))(1); // output: int&&
curry(RESOLVE_OVERLOAD(h))(1); // output: void h(int&&)
curry(RESOLVE_OVERLOAD(hh))(1); // output: void hh(auto:45&&) [with auto:45 = int]
}

真的正确了吗

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
template<typename T>
constexpr auto capture_forward(auto&& x) {
if constexpr (std::is_lvalue_reference_v<T&&>) {
return std::ref(std::forward<T>(x));
} else {
return std::forward<T>(x);
}
}

#define CAP_FWD(x) capture_forward<decltype(x)>(x)
#define STD_FWD(x) std::forward<decltype(x)>(x)
template<typename F>
auto curry(F&& f) {
if constexpr (std::is_invocable_v<F>) {
return f();
} else {
return [f = CAP_FWD(f)]<typename T>(T&& x) mutable {
return curry(
[f = CAP_FWD(f), x = CAP_FWD(x)]<typename... Ts>(Ts&&... xs) mutable -> std::invoke_result_t<F, T, Ts...> {
return std::invoke(f, STD_FWD(x), STD_FWD(xs)...);
});
};
}
}

#define RESOLVE_OVERLOAD(...) \
[]<typename... Ts, typename = decltype(__VA_ARGS__(std::declval<Ts>()...))> (Ts&&... xs) { \
return __VA_ARGS__(std::forward<Ts>(xs)...); \
}

#define RESOLVE_OVERLOAD_VAR(...) \
[__VA_ARGS__]<typename... Ts, typename = decltype(__VA_ARGS__(std::declval<Ts>()...))> (Ts&&... xs) { \
return __VA_ARGS__(std::forward<Ts>(xs)...); \
}

// call curry(RESOLVE_OVERLOAD(f)) for overloaded function f
// call curry(RESOLVE_OVERLOAD_VAR(f)) for overloaded function f (stored in variable)
// or call curry(f) else

我不知道。

一些细节

mutable

没有它的修饰,lambda 形成的匿名类中的捕获的东西就没法是左值引用。如果需要修改 capture 的东西,必须得加 mutable

std::invoke

其实 std::invoke(f, ...) 本来是写成 f(...),但是 std::invoke 做的事情并没有看上去那么简单,比如它可以作用于成员函数(this 得作为第一个参数传入),可以看一下它的 possible implementation,还是挺复杂的(我可不想把它引入 curry 的实现)。

手动指定 lambda 返回值

代码中有这样一句 -> std::invoke_result_t<F, T, Ts...>,本来是省略的,但是会导致过不了编译。曾经抱着试试看的心态手动制定了返回值,结果就能过编译了。但是那个 lambda 只有一个 return,如果把指定的返回值改成 decltype(std::invoke(...)) 也能通过编译。我一直以为 lambda 的返回值就是对于 return 语句中的表达式做了 decltype(如果只有一个 return 的话),但看上去并非如此。我也不知道为什么。

此处参考 https://en.cppreference.com/w/cpp/language/template_argument_deduction#auto-returning_functions

参考

期间打开过无数 StackOverflow,肯定还漏了一些有用的参考: