Line data Source code
1 : /*
2 : * Copyright (c) 2012-2015: G-CSC, Goethe University Frankfurt
3 : * Author: Andreas Vogel
4 : *
5 : * This file is part of UG4.
6 : *
7 : * UG4 is free software: you can redistribute it and/or modify it under the
8 : * terms of the GNU Lesser General Public License version 3 (as published by the
9 : * Free Software Foundation) with the following additional attribution
10 : * requirements (according to LGPL/GPL v3 §7):
11 : *
12 : * (1) The following notice must be displayed in the Appropriate Legal Notices
13 : * of covered and combined works: "Based on UG4 (www.ug4.org/license)".
14 : *
15 : * (2) The following notice must be displayed at a prominent place in the
16 : * terminal output of covered works: "Based on UG4 (www.ug4.org/license)".
17 : *
18 : * (3) The following bibliography is recommended for citation and must be
19 : * preserved in all covered files:
20 : * "Reiter, S., Vogel, A., Heppner, I., Rupp, M., and Wittum, G. A massively
21 : * parallel geometric multigrid solver on hierarchically distributed grids.
22 : * Computing and visualization in science 16, 4 (2013), 151-164"
23 : * "Vogel, A., Reiter, S., Rupp, M., Nägel, A., and Wittum, G. UG4 -- a novel
24 : * flexible software system for simulating pde based models on high performance
25 : * computers. Computing and visualization in science 16, 4 (2013), 165-179"
26 : *
27 : * This program is distributed in the hope that it will be useful,
28 : * but WITHOUT ANY WARRANTY; without even the implied warranty of
29 : * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30 : * GNU Lesser General Public License for more details.
31 : */
32 :
33 : #ifndef __H__UG_BRIDGE__BRIDGES__USER_DATA__USER_DATA_IMPL_
34 : #define __H__UG_BRIDGE__BRIDGES__USER_DATA__USER_DATA_IMPL_
35 :
36 : #ifdef UG_FOR_LUA
37 : #include "lua_user_data.h"
38 : #endif
39 : #include "lib_disc/spatial_disc/user_data/linker/linker_traits.h"
40 : #include "lib_disc/spatial_disc/user_data/const_user_data.h"
41 :
42 : #include "info_commands.h"
43 : #include "common/util/number_util.h"
44 :
45 : #if 0
46 : #define PROFILE_CALLBACK() PROFILE_FUNC_GROUP("luacallback")
47 : #define PROFILE_CALLBACK_BEGIN(name) PROFILE_BEGIN_GROUP(name, "luacallback")
48 : #define PROFILE_CALLBACK_END() PROFILE_END()
49 : #else
50 : #define PROFILE_CALLBACK()
51 : #define PROFILE_CALLBACK_BEGIN(name)
52 : #define PROFILE_CALLBACK_END()
53 : #endif
54 : namespace ug{
55 :
56 : #ifdef USE_LUA2C
57 : extern bool useLuaCompiler;
58 : #endif
59 :
60 :
61 :
62 : ////////////////////////////////////////////////////////////////////////////////
63 : // LuaUserData
64 : ////////////////////////////////////////////////////////////////////////////////
65 :
66 : template <typename TData, int dim, typename TRet>
67 0 : std::string LuaUserData<TData,dim,TRet>::signature()
68 : {
69 0 : std::stringstream ss;
70 0 : ss << "function name(";
71 0 : if(dim >= 1) ss << "x";
72 0 : if(dim >= 2) ss << ", y";
73 0 : if(dim >= 3) ss << ", z";
74 0 : ss << ", t, si)\n ... \n return ";
75 : if(lua_traits<TRet>::size != 0)
76 0 : ss << lua_traits<TRet>::signature() << ", ";
77 0 : ss << lua_traits<TData>::signature();
78 0 : ss << "\nend";
79 0 : return ss.str();
80 0 : }
81 :
82 :
83 : template <typename TData, int dim, typename TRet>
84 0 : std::string LuaUserData<TData,dim,TRet>::name()
85 : {
86 0 : std::stringstream ss;
87 0 : ss << "Lua";
88 0 : if(lua_traits<TRet>::size > 0) ss << "Cond";
89 0 : ss << "User" << lua_traits<TData>::name() << dim << "d";
90 0 : return ss.str();
91 0 : }
92 :
93 : template <typename TData, int dim, typename TRet>
94 0 : LuaUserData<TData,dim,TRet>::LuaUserData(const char* luaCallback)
95 0 : : m_callbackName(luaCallback), m_bFromFactory(false)
96 : {
97 : // get lua state
98 0 : m_L = ug::script::GetDefaultLuaState();
99 :
100 : // obtain a reference
101 0 : lua_getglobal(m_L, m_callbackName.c_str());
102 :
103 : // make sure that the reference is valid
104 0 : if(lua_isnil(m_L, -1)){
105 0 : UG_THROW(name() << ": Specified lua callback "
106 : "does not exist: " << m_callbackName);
107 : }
108 :
109 : // store reference to lua function
110 0 : m_callbackRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
111 :
112 : // make a test run
113 0 : check_callback_returns(m_L, m_callbackRef, m_callbackName.c_str(), true);
114 :
115 : #ifdef USE_LUA2C
116 : if(useLuaCompiler) m_luaComp.create(luaCallback);
117 : #endif
118 0 : }
119 :
120 : template <typename TData, int dim, typename TRet>
121 0 : LuaUserData<TData,dim,TRet>::LuaUserData(LuaFunctionHandle handle)
122 0 : : m_callbackName("__anonymous__lua__function__"), m_bFromFactory(false)
123 : {
124 : // get lua state
125 0 : m_L = ug::script::GetDefaultLuaState();
126 :
127 : // store reference to lua function
128 0 : m_callbackRef = handle.ref;
129 :
130 : // make a test run
131 0 : check_callback_returns(m_L, m_callbackRef, m_callbackName.c_str(), true);
132 :
133 : #ifdef USE_LUA2C
134 : // UG_THROW("LuaFunctionHandle usage currently not supported with LUA2C.");
135 : if(useLuaCompiler) m_luaComp.create(m_callbackName.c_str(), &handle);
136 : #endif
137 0 : }
138 :
139 :
140 : template <typename TData, int dim, typename TRet>
141 0 : bool LuaUserData<TData,dim,TRet>::
142 : check_callback_returns(lua_State* L, int callbackRef, const char* callName, const bool bThrow)
143 : {
144 : PROFILE_CALLBACK()
145 : // get current stack level
146 0 : const int level = lua_gettop(L);
147 :
148 : // dummy values to invoke the callback once
149 : MathVector<dim> x; x = 0.0;
150 : number time = 0.0;
151 : int si = 0;
152 :
153 : // push the callback function on the stack
154 0 : lua_rawgeti(L, LUA_REGISTRYINDEX, callbackRef);
155 :
156 : // push space coordinates on stack
157 : lua_traits<MathVector<dim> >::push(L, x);
158 :
159 : // push time on stack
160 : lua_traits<number>::push(L, time);
161 :
162 : // push subset on stack
163 : lua_traits<int>::push(L, si);
164 :
165 : // compute total args size
166 : const int argSize = lua_traits<MathVector<dim> >::size
167 : + lua_traits<number>::size
168 : + lua_traits<int>::size;
169 :
170 : // compute total return size
171 : const int retSize = lua_traits<TData>::size + lua_traits<TRet>::size;
172 :
173 : // call lua function
174 0 : if(lua_pcall(L, argSize, LUA_MULTRET, 0) != 0)
175 0 : UG_THROW(name() << ": Error while "
176 : "testing callback '" << callName << "',"
177 : " lua message: "<< lua_tostring(L, -1));
178 :
179 : // get number of results
180 0 : const int numResults = lua_gettop(L) - level;
181 :
182 : // success flag
183 : bool bRet = true;
184 :
185 : // if number of results is wrong return error
186 0 : if(numResults != retSize){
187 0 : if(bThrow){
188 0 : UG_THROW(name() << ": Number of return values incorrect "
189 : "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
190 : "\nRequired: "<<retSize<<", passed: "<<numResults
191 : <<". Use signature as follows:\n"
192 : << signature());
193 : }
194 : else{
195 : bRet = false;
196 : }
197 : }
198 :
199 : // check return value
200 0 : if(!lua_traits<TData>::check(L)){
201 0 : if(bThrow){
202 0 : UG_THROW(name() << ": Data values type incorrect "
203 : "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
204 : "\nUse signature as follows:\n"
205 : << signature());
206 : }
207 : else{
208 : bRet = false;
209 : }
210 : }
211 :
212 : // read return flag (may be void)
213 0 : if(!lua_traits<TRet>::check(L, -retSize)){
214 0 : if(bThrow){
215 0 : UG_THROW("LuaUserData: Return values type incorrect "
216 : "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
217 : "\nUse signature as follows:\n"
218 : << signature());
219 : }
220 : else{
221 : bRet = false;
222 : }
223 : }
224 :
225 : // pop values
226 0 : lua_pop(L, numResults);
227 :
228 : // return match
229 0 : return bRet;
230 : }
231 :
232 : template <typename TData, int dim, typename TRet>
233 : bool LuaUserData<TData,dim,TRet>::
234 : check_callback_returns(LuaFunctionHandle handle, const bool bThrow)
235 : {
236 : PROFILE_CALLBACK()
237 : // get lua state
238 0 : lua_State* L = ug::script::GetDefaultLuaState();
239 :
240 : // forward call
241 0 : bool bRet = check_callback_returns(L, handle.ref, "__lua_function_handle__", bThrow);
242 :
243 : // return match
244 : return bRet;
245 : }
246 :
247 : template <typename TData, int dim, typename TRet>
248 0 : bool LuaUserData<TData,dim,TRet>::
249 : check_callback_returns(const char* callName, const bool bThrow)
250 : {
251 : PROFILE_CALLBACK()
252 : // get lua state
253 0 : lua_State* L = ug::script::GetDefaultLuaState();
254 :
255 : // obtain a reference
256 0 : lua_getglobal(L, callName);
257 :
258 : // check if reference is valid
259 0 : if(lua_isnil(L, -1)) {
260 0 : if(bThrow) {
261 0 : UG_THROW(name() << ": Cannot find specified lua callback "
262 : " with name: "<<callName);
263 : }
264 : else {
265 : return false;
266 : }
267 : }
268 :
269 : // get reference
270 0 : int callbackRef = luaL_ref(L, LUA_REGISTRYINDEX);
271 :
272 : // forward call
273 0 : bool bRet = check_callback_returns(L, callbackRef, callName, bThrow);
274 :
275 : // free reference to callback
276 0 : luaL_unref(L, LUA_REGISTRYINDEX, callbackRef);
277 :
278 : // return match
279 0 : return bRet;
280 : }
281 :
282 : template <typename TData, int dim, typename TRet>
283 0 : TRet LuaUserData<TData,dim,TRet>::
284 : evaluate(TData& D, const MathVector<dim>& x, number time, int si) const
285 : {
286 : PROFILE_CALLBACK()
287 : #ifdef USE_LUA2C
288 : if(useLuaCompiler && m_luaComp.is_valid())
289 : {
290 : double d[dim+2];
291 : for(int i=0; i<dim; i++)
292 : d[i] = x[i];
293 : d[dim] = time;
294 : d[dim+1] = si;
295 : double ret[lua_traits<TData>::size+1];
296 : m_luaComp.call(ret, d);
297 : //TData D2;
298 : TRet *t=NULL;
299 : lua_traits<TData>::read(D, ret, t);
300 : return lua_traits<TRet>::do_return(ret[0]);
301 : }
302 : else
303 : #endif
304 : {
305 : // push the callback function on the stack
306 0 : lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_callbackRef);
307 :
308 : // push space coordinates on stack
309 0 : lua_traits<MathVector<dim> >::push(m_L, x);
310 :
311 : // push time on stack
312 0 : lua_traits<number>::push(m_L, time);
313 :
314 : // push subset index on stack
315 0 : lua_traits<int>::push(m_L, si);
316 :
317 : // compute total args size
318 : const int argSize = lua_traits<MathVector<dim> >::size
319 : + lua_traits<number>::size
320 : + lua_traits<int>::size;
321 :
322 : // compute total return size
323 : const int retSize = lua_traits<TData>::size + lua_traits<TRet>::size;
324 :
325 : // call lua function
326 0 : if(lua_pcall(m_L, argSize, retSize, 0) != 0)
327 0 : UG_THROW(name() << "::operator(...): Error while "
328 : "running callback '" << m_callbackName << "',"
329 : " lua message: "<< lua_tostring(m_L, -1)<<".\n"
330 : "Use signature as follows:\n"
331 : << signature());
332 :
333 : bool res = false;
334 : try{
335 : // read return value
336 0 : lua_traits<TData>::read(m_L, D);
337 :
338 : // read return flag (may be void)
339 0 : lua_traits<TRet>::read(m_L, res, -retSize);
340 : }
341 0 : UG_CATCH_THROW(name() << "::operator(...): Error while running "
342 : "callback '" << m_callbackName << "'.\n"
343 : "Use signature as follows:\n"
344 : << signature());
345 :
346 : // pop values
347 0 : lua_pop(m_L, retSize);
348 :
349 : // forward flag
350 0 : return lua_traits<TRet>::do_return(res);
351 : }
352 : }
353 :
354 : template <typename TData, int dim, typename TRet>
355 0 : LuaUserData<TData,dim,TRet>::~LuaUserData()
356 : {
357 : // free reference to callback
358 0 : luaL_unref(m_L, LUA_REGISTRYINDEX, m_callbackRef);
359 :
360 0 : if(m_bFromFactory)
361 0 : LuaUserDataFactory<TData,dim,TRet>::remove(m_callbackName);
362 0 : }
363 :
364 : ////////////////////////////////////////////////////////////////////////////////
365 : // LuaUserDataFactory
366 : ////////////////////////////////////////////////////////////////////////////////
367 :
368 : template <typename TData, int dim, typename TRet>
369 : SmartPtr<LuaUserData<TData,dim,TRet> >
370 0 : LuaUserDataFactory<TData,dim,TRet>::provide_or_create(const std::string& name)
371 : {
372 : PROFILE_CALLBACK();
373 : typedef std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> > Map;
374 : typedef typename Map::iterator iterator;
375 :
376 : // check for element
377 : iterator iter = m_mData.find(name);
378 :
379 : // if name does not exist, create new one
380 0 : if(iter == m_mData.end())
381 : {
382 : SmartPtr<LuaUserData<TData,dim,TRet> > sp
383 0 : = make_sp(new LuaUserData<TData,dim,TRet>(name.c_str()));
384 :
385 : // the LuaUserData must remember to unregister itself at destruction
386 : sp->set_created_from_factory(true);
387 :
388 : // NOTE AND WARNING: This is very hacky and dangerous. We only do this
389 : // since we exactly know what we are doing and everything is save and
390 : // only in protected or private area. However, if you once want to change
391 : // this code, please be aware, that we store here plain pointers and
392 : // associated reference counters of a SmartPtr. This should not be done
393 : // in general and this kind of coding is not recommended at all. Please
394 : // use different approaches whenever possible.
395 0 : std::pair<LuaUserData<TData,dim,TRet>*, int*>& data = m_mData[name];
396 0 : data.first = sp.get();
397 0 : data.second = sp.refcount_ptr();
398 :
399 : return sp;
400 : }
401 : // else return present data
402 : {
403 : // NOTE AND WARNING: This is very hacky and dangerous. We only do this
404 : // since we exactly know what we are doing and everything is save and
405 : // only in protected or private area. However, if you once want to change
406 : // this code, please be aware, that we store here plain pointers and
407 : // associated reference counters of a SmartPtr. This should not be done
408 : // in general and this kind of coding is not recommended at all. Please
409 : // use different approaches whenever possible.
410 : std::pair<LuaUserData<TData,dim,TRet>*, int*>& data = iter->second;
411 0 : return SmartPtr<LuaUserData<TData,dim,TRet> >(data.first, data.second);
412 : }
413 : }
414 :
415 : template <typename TData, int dim, typename TRet>
416 : void
417 0 : LuaUserDataFactory<TData,dim,TRet>::remove(const std::string& name)
418 : {
419 : typedef std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> > Map;
420 : typedef typename Map::iterator iterator;
421 :
422 : // check for element
423 : iterator iter = m_mData.find(name);
424 :
425 : // if name does not exist, create new one
426 0 : if(iter == m_mData.end())
427 0 : UG_THROW("LuaUserDataFactory: trying to remove non-registered"
428 : " data with name: "<<name);
429 :
430 : m_mData.erase(iter);
431 0 : }
432 :
433 :
434 : // instantiation of static member
435 : template <typename TData, int dim, typename TRet>
436 : std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> >
437 : LuaUserDataFactory<TData,dim,TRet>::m_mData = std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> >();
438 :
439 : ////////////////////////////////////////////////////////////////////////////////
440 : // LuaUserFunction
441 : ////////////////////////////////////////////////////////////////////////////////
442 :
443 : template <typename TData, int dim, typename TDataIn>
444 0 : LuaUserFunction<TData,dim,TDataIn>::
445 : LuaUserFunction(const char* luaCallback, size_t numArgs)
446 0 : : m_numArgs(numArgs), m_bPosTimeNeed(false)
447 : {
448 0 : m_L = ug::script::GetDefaultLuaState();
449 0 : m_cbValueRef = LUA_NOREF;
450 : m_cbDerivRef.clear();
451 : m_cbDerivName.clear();
452 0 : set_lua_value_callback(luaCallback, numArgs);
453 : #ifdef USE_LUA2C
454 : if(useLuaCompiler) m_luaComp.create(luaCallback);
455 : #endif
456 0 : }
457 :
458 : template <typename TData, int dim, typename TDataIn>
459 0 : LuaUserFunction<TData,dim,TDataIn>::
460 : LuaUserFunction(const char* luaCallback, size_t numArgs, bool bPosTimeNeed)
461 0 : : m_numArgs(numArgs), m_bPosTimeNeed(bPosTimeNeed)
462 : {
463 0 : m_L = ug::script::GetDefaultLuaState();
464 0 : m_cbValueRef = LUA_NOREF;
465 : m_cbDerivRef.clear();
466 : m_cbDerivName.clear();
467 0 : set_lua_value_callback(luaCallback, numArgs);
468 : #ifdef USE_LUA2C
469 : m_luaComp_Deriv.clear();
470 : #endif
471 0 : }
472 :
473 :
474 : template <typename TData, int dim, typename TDataIn>
475 0 : LuaUserFunction<TData,dim,TDataIn>::
476 : LuaUserFunction(LuaFunctionHandle handle, size_t numArgs)
477 0 : : m_numArgs(numArgs), m_bPosTimeNeed(false)
478 : {
479 0 : m_L = ug::script::GetDefaultLuaState();
480 0 : m_cbValueRef = LUA_NOREF;
481 : m_cbDerivRef.clear();
482 : m_cbDerivName.clear();
483 0 : set_lua_value_callback(handle, numArgs);
484 : #ifdef USE_LUA2C
485 : if(useLuaCompiler){
486 : UG_LOG("WARNING (in LuaUserFunction): LUA2C compiler "
487 : "can't be executed for FunctionHandle.\n");
488 : }
489 : #endif
490 0 : }
491 :
492 : template <typename TData, int dim, typename TDataIn>
493 0 : LuaUserFunction<TData,dim,TDataIn>::
494 : LuaUserFunction(LuaFunctionHandle handle, size_t numArgs, bool bPosTimeNeed)
495 0 : : m_numArgs(numArgs), m_bPosTimeNeed(bPosTimeNeed)
496 : {
497 0 : m_L = ug::script::GetDefaultLuaState();
498 0 : m_cbValueRef = LUA_NOREF;
499 : m_cbDerivRef.clear();
500 : m_cbDerivName.clear();
501 0 : set_lua_value_callback(handle, numArgs);
502 : #ifdef USE_LUA2C
503 : m_luaComp_Deriv.clear();
504 : #endif
505 0 : }
506 :
507 :
508 :
509 : template <typename TData, int dim, typename TDataIn>
510 0 : LuaUserFunction<TData,dim,TDataIn>::~LuaUserFunction()
511 : {
512 : // free reference to callback
513 : free_callback_ref();
514 :
515 : // free references to derivate callbacks
516 0 : for(size_t i = 0; i < m_numArgs; ++i){
517 0 : free_deriv_callback_ref(i);
518 : }
519 0 : }
520 :
521 : template <typename TData, int dim, typename TDataIn>
522 : void LuaUserFunction<TData,dim,TDataIn>::free_callback_ref()
523 : {
524 0 : if(m_cbValueRef != LUA_NOREF){
525 0 : luaL_unref(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
526 0 : m_cbValueRef = LUA_NOREF;
527 : }
528 : }
529 :
530 : template <typename TData, int dim, typename TDataIn>
531 0 : void LuaUserFunction<TData,dim,TDataIn>::free_deriv_callback_ref(size_t arg)
532 : {
533 0 : if(m_cbDerivRef[arg] != LUA_NOREF){
534 0 : luaL_unref(m_L, LUA_REGISTRYINDEX, m_cbDerivRef[arg]);
535 0 : m_cbDerivRef[arg] = LUA_NOREF;
536 : }
537 0 : }
538 :
539 :
540 : template <typename TData, int dim, typename TDataIn>
541 0 : void LuaUserFunction<TData,dim,TDataIn>::set_lua_value_callback(const char* luaCallback, size_t numArgs)
542 : {
543 : // store name (string) of callback
544 0 : m_cbValueName = luaCallback;
545 :
546 : // obtain a reference
547 0 : lua_getglobal(m_L, m_cbValueName.c_str());
548 :
549 : // make sure that the reference is valid
550 0 : if(lua_isnil(m_L, -1)){
551 0 : UG_THROW("LuaUserFunction::set_lua_value_callback(...):"
552 : "Specified callback does not exist: " << m_cbValueName);
553 : }
554 :
555 : // if a callback was already set, we have to free the old one
556 : free_callback_ref();
557 :
558 : // store reference to lua function
559 0 : m_cbValueRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
560 :
561 : // remember number of arguments to be used
562 0 : m_numArgs = numArgs;
563 0 : m_cbDerivName.resize(numArgs);
564 0 : m_cbDerivRef.resize(numArgs, LUA_NOREF);
565 :
566 : // set num inputs for linker
567 0 : set_num_input(numArgs);
568 :
569 : #ifdef USE_LUA2C
570 : m_luaComp_Deriv.resize(numArgs);
571 : #endif
572 0 : }
573 :
574 : template <typename TData, int dim, typename TDataIn>
575 0 : void LuaUserFunction<TData,dim,TDataIn>::
576 : set_lua_value_callback(LuaFunctionHandle handle, size_t numArgs)
577 : {
578 : // store name (string) of callback
579 0 : m_cbValueName = "__anonymous__lua__function__";
580 :
581 : // if a callback was already set, we have to free the old one
582 : free_callback_ref();
583 :
584 : // store reference to lua function
585 0 : m_cbValueRef = handle.ref;
586 :
587 : // remember number of arguments to be used
588 0 : m_numArgs = numArgs;
589 0 : m_cbDerivName.resize(numArgs);
590 0 : m_cbDerivRef.resize(numArgs, LUA_NOREF);
591 :
592 : // set num inputs for linker
593 0 : set_num_input(numArgs);
594 :
595 : #ifdef USE_LUA2C
596 : m_luaComp_Deriv.resize(numArgs);
597 : #endif
598 0 : }
599 :
600 : template <typename TData, int dim, typename TDataIn>
601 0 : void LuaUserFunction<TData,dim,TDataIn>::set_deriv(size_t arg, const char* luaCallback)
602 : {
603 : // check number of arg
604 0 : if(arg >= m_numArgs)
605 0 : UG_THROW("LuaUserFunction::set_lua_deriv_callback: Trying "
606 : "to set a derivative for argument " << arg <<", that "
607 : "does not exist. Number of arguments is "<<m_numArgs);
608 :
609 : // store name (string) of callback
610 : m_cbDerivName[arg] = luaCallback;
611 :
612 : // free old reference
613 0 : free_deriv_callback_ref(arg);
614 :
615 : // obtain a reference
616 0 : lua_getglobal(m_L, m_cbDerivName[arg].c_str());
617 :
618 : // make sure that the reference is valid
619 0 : if(lua_isnil(m_L, -1)){
620 0 : UG_THROW("LuaUserFunction::set_lua_deriv_callback(...):"
621 : "Specified callback does not exist: " << m_cbDerivName[arg]);
622 : }
623 :
624 : // store reference to lua function
625 0 : m_cbDerivRef[arg] = luaL_ref(m_L, LUA_REGISTRYINDEX);
626 :
627 : #ifdef USE_LUA2C
628 : if(useLuaCompiler) m_luaComp_Deriv[arg].create(luaCallback);
629 : #endif
630 :
631 0 : }
632 :
633 : template <typename TData, int dim, typename TDataIn>
634 0 : void LuaUserFunction<TData,dim,TDataIn>::set_deriv(size_t arg, LuaFunctionHandle handle)
635 : {
636 : // check number of arg
637 0 : if(arg >= m_numArgs)
638 0 : UG_THROW("LuaUserFunction::set_lua_deriv_callback: Trying "
639 : "to set a derivative for argument " << arg <<", that "
640 : "does not exist. Number of arguments is "<<m_numArgs);
641 :
642 : // store name (string) of callback
643 0 : m_cbDerivName[arg] = std::string("__anonymous__lua__function__");
644 :
645 : // free old reference
646 0 : free_deriv_callback_ref(arg);
647 :
648 : // store reference to lua function
649 0 : m_cbDerivRef[arg] = handle.ref;
650 :
651 : #ifdef USE_LUA2C
652 : // if(useLuaCompiler) m_luaComp_Deriv[arg].create(luaCallback);
653 : #endif
654 :
655 0 : }
656 :
657 :
658 :
659 :
660 : template <typename TData, int dim, typename TDataIn>
661 0 : void LuaUserFunction<TData,dim,TDataIn>::operator() (TData& out, int numArgs, ...) const
662 : {
663 : PROFILE_CALLBACK();
664 : #ifdef USE_LUA2C
665 : if(useLuaCompiler && m_luaComp.is_valid())
666 : {
667 : double d[20];
668 : // get list of arguments
669 : va_list ap2;
670 : va_start(ap2, numArgs);
671 :
672 : // read all arguments and push them to the lua stack
673 : for(int i = 0; i < numArgs; ++i)
674 : d[i] = va_arg(ap2, double);
675 : va_end(ap2);
676 :
677 : double ret[lua_traits<TData>::size+1];
678 :
679 : UG_ASSERT(m_luaComp.num_in() == numArgs && m_luaComp.num_out() == lua_traits<TData>::size,
680 : m_luaComp.name() << ", " << m_luaComp.num_in() << " != " << numArgs << " or " << m_luaComp.num_out() << " != " << lua_traits<TData>::size);
681 : m_luaComp.call(ret, d);
682 : //TData D2;
683 : void *t=NULL;
684 : //TData out2;
685 : lua_traits<TData>::read(out, ret, t);
686 : return;
687 : }
688 : else
689 : #endif
690 : {
691 : UG_ASSERT(numArgs == (int)m_numArgs, "Number of arguments mismatched.");
692 :
693 : // push the callback function on the stack
694 0 : lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
695 :
696 : // get list of arguments
697 : va_list ap;
698 0 : va_start(ap, numArgs);
699 :
700 : // read all arguments and push them to the lua stack
701 0 : for(int i = 0; i < numArgs; ++i)
702 : {
703 : // cast data
704 0 : TDataIn val = va_arg(ap, TDataIn);
705 :
706 : // push data to lua stack
707 0 : lua_traits<TDataIn>::push(m_L, val);
708 : }
709 :
710 : // end read in of parameters
711 0 : va_end(ap);
712 :
713 : // compute total args size
714 : size_t argSize = lua_traits<TDataIn>::size * numArgs;
715 :
716 : // compute total return size
717 : size_t retSize = lua_traits<TData>::size;
718 :
719 : // call lua function
720 0 : if(lua_pcall(m_L, argSize, retSize, 0) != 0)
721 0 : UG_THROW("LuaUserFunction::operator(...): Error while "
722 : "running callback '" << m_cbValueName << "',"
723 : " lua message: "<< lua_tostring(m_L, -1));
724 :
725 : try{
726 : // read return value
727 0 : lua_traits<TData>::read(m_L, out);
728 0 : UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
729 : }
730 0 : UG_CATCH_THROW("LuaUserFunction::operator(...): Error while running "
731 : "callback '" << m_cbValueName << "'");
732 :
733 : // pop values
734 0 : lua_pop(m_L, retSize);
735 : }
736 0 : }
737 :
738 :
739 : template <typename TData, int dim, typename TDataIn>
740 0 : void LuaUserFunction<TData,dim,TDataIn>::eval_value(TData& out, const std::vector<TDataIn>& dataIn,
741 : const MathVector<dim>& x, number time, int si) const
742 : {
743 : PROFILE_CALLBACK();
744 : #ifdef USE_LUA2C
745 : if(useLuaCompiler && m_luaComp.is_valid())
746 : {
747 : double d[20];
748 :
749 : // read all arguments and push them to the lua stack
750 : for(size_t i = 0; i < dataIn.size(); ++i)
751 : d[i] = dataIn[i];
752 : if(m_bPosTimeNeed){
753 : for(int i=0; i<dim; i++)
754 : d[i+m_numArgs] = x[i];
755 : d[dim+m_numArgs]=time;
756 : d[dim+m_numArgs+1]=si;
757 : UG_ASSERT(dim+m_numArgs+1 < 20, m_luaComp.name());
758 : }
759 :
760 : double ret[lua_traits<TData>::size];
761 : m_luaComp.call(ret, d);
762 : //TData D2;
763 : void *t=NULL;
764 : //TData out2;
765 : UG_ASSERT(m_luaComp.num_out() == lua_traits<TData>::size, m_luaComp.name() << ", " << m_luaComp.num_out() << " != " << lua_traits<TData>::size);
766 : lua_traits<TData>::read(out, ret, t);
767 : return;
768 : }
769 : else
770 : #endif
771 : {
772 : UG_ASSERT(dataIn.size() == m_numArgs, "Number of arguments mismatched.");
773 :
774 : // push the callback function on the stack
775 0 : lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
776 :
777 : // read all arguments and push them to the lua stack
778 0 : for(size_t i = 0; i < dataIn.size(); ++i)
779 : {
780 : // push data to lua stack
781 0 : lua_traits<TDataIn>::push(m_L, dataIn[i]);
782 : }
783 :
784 : // if needed, read additional coordinate, time and subset index arguments and push them to the lua stack
785 0 : if(m_bPosTimeNeed){
786 0 : lua_traits<MathVector<dim> >::push(m_L, x);
787 0 : lua_traits<number>::push(m_L, time);
788 0 : lua_traits<int>::push(m_L, si);
789 : }
790 :
791 : // compute total args size
792 : size_t argSize = lua_traits<TDataIn>::size * dataIn.size();
793 0 : if(m_bPosTimeNeed){
794 0 : argSize += lua_traits<MathVector<dim> >::size
795 : + lua_traits<number>::size
796 : + lua_traits<int>::size;
797 : }
798 :
799 : // compute total return size
800 : size_t retSize = lua_traits<TData>::size;
801 :
802 : // call lua function
803 0 : if(lua_pcall(m_L, argSize, retSize, 0) != 0)
804 0 : UG_THROW("LuaUserFunction::eval_value(...): Error while "
805 : "running callback '" << m_cbValueName << "',"
806 : " lua message: "<< lua_tostring(m_L, -1));
807 :
808 : try{
809 : // read return value
810 0 : lua_traits<TData>::read(m_L, out);
811 0 : UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
812 : }
813 0 : UG_CATCH_THROW("LuaUserFunction::eval_value(...): Error while "
814 : "running callback '" << m_cbValueName << "'");
815 :
816 : // pop values
817 0 : lua_pop(m_L, retSize);
818 : }
819 0 : }
820 :
821 :
822 : template <typename TData, int dim, typename TDataIn>
823 0 : void LuaUserFunction<TData,dim,TDataIn>::eval_deriv(TData& out, const std::vector<TDataIn>& dataIn,
824 : const MathVector<dim>& x, number time, int si, size_t arg) const
825 : {
826 : PROFILE_CALLBACK();
827 : #ifdef USE_LUA2C
828 : if(useLuaCompiler && m_luaComp_Deriv[arg].is_valid()
829 : && dim+m_numArgs+1 < 20 && m_luaComp_Deriv[arg].num_out() == lua_traits<TData>::size)
830 : {
831 : const bridge::LUACompiler &luaComp = m_luaComp_Deriv[arg];
832 : double d[25];
833 : UG_ASSERT(dim+m_numArgs+1 < 20, luaComp.name());
834 : for(size_t i=0; i<m_numArgs; i++)
835 : d[i] = dataIn[i];
836 : if(m_bPosTimeNeed){
837 : for(int i=0; i<dim; i++)
838 : d[i+m_numArgs] = x[i];
839 : d[dim+m_numArgs]=time;
840 : d[dim+m_numArgs+1]=si;
841 : UG_ASSERT(dim+m_numArgs+1 < 20, luaComp.name());
842 : }
843 : UG_ASSERT(luaComp.num_out() == lua_traits<TData>::size,
844 : luaComp.name() << " has wrong number of outputs: is " << luaComp.num_out() << ", needs " << lua_traits<TData>::size);
845 : double ret[lua_traits<TData>::size+1];
846 : luaComp.call(ret, d);
847 : //TData D2;
848 : void *t=NULL;
849 : //TData out2;
850 : lua_traits<TData>::read(out, ret, t);
851 : return;
852 : }
853 : else
854 : #endif
855 : {
856 : UG_ASSERT(dataIn.size() == m_numArgs, "Number of arguments mismatched.");
857 : UG_ASSERT(arg < m_numArgs, "Argument does not exist.");
858 :
859 : // push the callback function on the stack
860 0 : lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbDerivRef[arg]);
861 :
862 : // read all arguments and push them to the lua stack
863 0 : for(size_t i = 0; i < dataIn.size(); ++i)
864 : {
865 : // push data to lua stack
866 0 : lua_traits<TDataIn>::push(m_L, dataIn[i]);
867 : }
868 :
869 : // if needed, read additional coordinate, time and subset index arguments and push them to the lua stack
870 0 : if(m_bPosTimeNeed){
871 0 : lua_traits<MathVector<dim> >::push(m_L, x);
872 0 : lua_traits<number>::push(m_L, time);
873 0 : lua_traits<int>::push(m_L, si);
874 : }
875 :
876 : // compute total args size
877 : size_t argSize = lua_traits<TDataIn>::size * dataIn.size();
878 0 : if(m_bPosTimeNeed){
879 0 : argSize += lua_traits<MathVector<dim> >::size
880 : + lua_traits<number>::size
881 : + lua_traits<int>::size;
882 : }
883 :
884 : // compute total return size
885 : size_t retSize = lua_traits<TData>::size;
886 :
887 : // call lua function
888 0 : if(lua_pcall(m_L, argSize, retSize, 0) != 0)
889 0 : UG_THROW("LuaUserFunction::eval_deriv: Error while "
890 : "running callback '" << m_cbDerivName[arg] << "',"
891 : " lua message: "<< lua_tostring(m_L, -1) );
892 :
893 : try{
894 : // read return value
895 0 : lua_traits<TData>::read(m_L, out);
896 0 : UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
897 : }
898 0 : UG_CATCH_THROW("LuaUserFunction::eval_deriv(...): Error while "
899 : "running callback '" << m_cbDerivName[arg] << "'");
900 :
901 : // pop values
902 0 : lua_pop(m_L, retSize);
903 : }
904 0 : }
905 :
906 :
907 : template <typename TData, int dim, typename TDataIn>
908 0 : void LuaUserFunction<TData,dim,TDataIn>::
909 : evaluate (TData& value,
910 : const MathVector<dim>& globIP,
911 : number time, int si) const
912 : {
913 : PROFILE_CALLBACK();
914 : // vector of data for all inputs
915 0 : std::vector<TDataIn> vDataIn(this->num_input());
916 :
917 : // gather all input data for this ip
918 0 : for(size_t c = 0; c < vDataIn.size(); ++c)
919 0 : (*m_vpUserData[c])(vDataIn[c], globIP, time, si);
920 :
921 : // evaluate data at ip
922 0 : eval_value(value, vDataIn, globIP, time, si);
923 :
924 0 : UG_COND_THROW(IsFiniteAndNotTooBig(value)==false, value);
925 0 : }
926 :
927 : template <typename TData, int dim, typename TDataIn>
928 : template <int refDim>
929 0 : void LuaUserFunction<TData,dim,TDataIn>::
930 : evaluate(TData vValue[],
931 : const MathVector<dim> vGlobIP[],
932 : number time, int si,
933 : GridObject* elem,
934 : const MathVector<dim> vCornerCoords[],
935 : const MathVector<refDim> vLocIP[],
936 : const size_t nip,
937 : LocalVector* u,
938 : const MathMatrix<refDim, dim>* vJT) const
939 : {
940 : PROFILE_CALLBACK();
941 : // vector of data for all inputs
942 0 : std::vector<TDataIn> vDataIn(this->num_input());
943 :
944 : // gather all input data for this ip
945 0 : for(size_t ip = 0; ip < nip; ++ip)
946 : {
947 0 : for(size_t c = 0; c < vDataIn.size(); ++c)
948 0 : (*m_vpUserData[c])(vDataIn[c], vGlobIP[ip], time, si, elem, vCornerCoords, vLocIP[ip], u);
949 :
950 : // evaluate data at ip
951 0 : eval_value(vValue[ip], vDataIn, vGlobIP[ip], time, si);
952 0 : UG_COND_THROW(IsFiniteAndNotTooBig(vValue[ip])==false, vValue[ip]);
953 : }
954 0 : }
955 :
956 : template <typename TData, int dim, typename TDataIn>
957 : template <int refDim>
958 0 : void LuaUserFunction<TData,dim,TDataIn>::
959 : eval_and_deriv(TData vValue[],
960 : const MathVector<dim> vGlobIP[],
961 : number time, int si,
962 : GridObject* elem,
963 : const MathVector<dim> vCornerCoords[],
964 : const MathVector<refDim> vLocIP[],
965 : const size_t nip,
966 : LocalVector* u,
967 : bool bDeriv,
968 : int s,
969 : std::vector<std::vector<TData> > vvvDeriv[],
970 : const MathMatrix<refDim, dim>* vJT)
971 : {
972 : PROFILE_CALLBACK();
973 : // vector of data for all inputs
974 0 : std::vector<TDataIn> vDataIn(this->num_input());
975 :
976 0 : for(size_t ip = 0; ip < nip; ++ip)
977 : {
978 : // gather all input data for this ip
979 0 : for(size_t c = 0; c < vDataIn.size(); ++c)
980 0 : vDataIn[c] = m_vpUserData[c]->value(this->series_id(c,s), ip);
981 :
982 : // evaluate data at ip
983 0 : eval_value(vValue[ip], vDataIn, vGlobIP[ip], time, si);
984 : }
985 :
986 : // check if derivative is required
987 0 : if(!bDeriv || this->zero_derivative()) return;
988 :
989 : // clear all derivative values
990 0 : this->set_zero(vvvDeriv, nip);
991 :
992 : // loop all inputs
993 0 : for(size_t c = 0; c < vDataIn.size(); ++c)
994 : {
995 : // check if we have the derivative w.r.t. this input, and the input has derivative
996 0 : if(m_cbDerivRef[c] == LUA_NOREF || m_vpUserData[c]->zero_derivative()) continue;
997 :
998 : // loop ips
999 0 : for(size_t ip = 0; ip < nip; ++ip)
1000 : {
1001 : // gather all input data for this ip
1002 0 : for(size_t i = 0; i < vDataIn.size(); ++i)
1003 0 : vDataIn[i] = m_vpUserData[i]->value(this->series_id(c,s), ip); //< series_id(c,s) or series_id(i,s)
1004 :
1005 : // data of derivative w.r.t. one component at ip-values
1006 : TData derivVal;
1007 :
1008 : // evaluate data at ip
1009 0 : eval_deriv(derivVal, vDataIn, vGlobIP[ip], time, si, c);
1010 :
1011 : // loop functions
1012 0 : for(size_t fct = 0; fct < this->input_num_fct(c); ++fct)
1013 : {
1014 : // get common fct id for this function
1015 : const size_t commonFct = this->input_common_fct(c, fct);
1016 :
1017 : // loop dofs
1018 0 : for(size_t dof = 0; dof < this->num_sh(fct); ++dof)
1019 : {
1020 : linker_traits<TData, TDataIn>::
1021 0 : mult_add(vvvDeriv[ip][commonFct][dof],
1022 : derivVal,
1023 0 : m_vpDependData[c]->deriv(this->series_id(c,s), ip, fct, dof));
1024 0 : UG_COND_THROW(IsFiniteAndNotTooBig(vvvDeriv[ip][commonFct][dof])==false, vvvDeriv[ip][commonFct][dof]);
1025 : }
1026 : }
1027 : }
1028 : }
1029 0 : }
1030 :
1031 : /**
1032 : * TODO: Note this is a public (non-virtual) function whose argument
1033 : * should be consistent with the number of the arguments. Should not it also
1034 : * resize the array for the references to the derivatives?
1035 : */
1036 : template <typename TData, int dim, typename TDataIn>
1037 0 : void LuaUserFunction<TData,dim,TDataIn>::set_num_input(size_t num)
1038 : {
1039 : // resize arrays
1040 0 : m_vpUserData.resize(num);
1041 0 : m_vpDependData.resize(num);
1042 :
1043 : // forward size to base class
1044 : base_type::set_num_input(num);
1045 0 : }
1046 :
1047 : template <typename TData, int dim, typename TDataIn>
1048 0 : void LuaUserFunction<TData,dim,TDataIn>::
1049 : set_input(size_t i, SmartPtr<CplUserData<TDataIn, dim> > data)
1050 : {
1051 : UG_ASSERT(i < m_vpUserData.size(), "Input not needed");
1052 : UG_ASSERT(i < m_vpDependData.size(), "Input not needed");
1053 :
1054 : // check input number
1055 0 : if(i >= this->num_input())
1056 0 : UG_THROW("LuaUserFunction::set_input: Only " << this->num_input()
1057 : << " inputs can be set. Use 'set_num_input' to increase"
1058 : " the number of needed inputs.");
1059 :
1060 : // remember userdata
1061 0 : m_vpUserData[i] = data;
1062 :
1063 : // cast to dependent data
1064 0 : m_vpDependData[i] = data.template cast_dynamic<DependentUserData<TDataIn, dim> >();
1065 :
1066 : // forward to base class
1067 0 : base_type::set_input(i, data, data);
1068 0 : }
1069 :
1070 : template <typename TData, int dim, typename TDataIn>
1071 0 : void LuaUserFunction<TData,dim,TDataIn>::set_input(size_t i, number val)
1072 : {
1073 0 : set_input(i, CreateConstUserData<dim>(val, TDataIn()));
1074 0 : }
1075 :
1076 :
1077 : ////////////////////////////////////////////////////////////////////////////////
1078 : // LuaFunction
1079 : ////////////////////////////////////////////////////////////////////////////////
1080 :
1081 : template <typename TData, typename TDataIn>
1082 0 : LuaFunction<TData,TDataIn>::LuaFunction() : m_numArgs(0)
1083 : {
1084 0 : m_L = ug::script::GetDefaultLuaState();
1085 0 : m_cbValueRef = LUA_NOREF;
1086 0 : }
1087 :
1088 : template <typename TData, typename TDataIn>
1089 0 : void LuaFunction<TData,TDataIn>::set_lua_callback(const char* luaCallback, size_t numArgs)
1090 : {
1091 : // store name (string) of callback
1092 0 : m_cbValueName = luaCallback;
1093 :
1094 : // obtain a reference
1095 0 : lua_getglobal(m_L, m_cbValueName.c_str());
1096 :
1097 : // make sure that the reference is valid
1098 0 : if(lua_isnil(m_L, -1)){
1099 0 : UG_THROW("LuaFunction::set_lua_callback(...):"
1100 : "Specified lua callback does not exist: " << m_cbValueName);
1101 : }
1102 :
1103 : // store reference to lua function
1104 0 : m_cbValueRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
1105 :
1106 : // remember number of arguments to be used
1107 0 : m_numArgs = numArgs;
1108 0 : }
1109 :
1110 : template <typename TData, typename TDataIn>
1111 0 : void LuaFunction<TData,TDataIn>::operator() (TData& out, int numArgs, ...)
1112 : {
1113 : PROFILE_CALLBACK_BEGIN(operatorBracket);
1114 : UG_ASSERT(numArgs == (int)m_numArgs, "Number of arguments mismatched.");
1115 :
1116 : // push the callback function on the stack
1117 0 : lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
1118 :
1119 : // get list of arguments
1120 : va_list ap;
1121 0 : va_start(ap, numArgs);
1122 :
1123 : // read all arguments and push them to the lua stack
1124 0 : for(int i = 0; i < numArgs; ++i)
1125 : {
1126 : // cast data
1127 0 : TDataIn val = va_arg(ap, TDataIn);
1128 :
1129 : // push data to lua stack
1130 0 : lua_traits<TDataIn>::push(m_L, val);
1131 : }
1132 :
1133 : // end read in of parameters
1134 0 : va_end(ap);
1135 :
1136 : // compute total args size
1137 : size_t argSize = lua_traits<TDataIn>::size * numArgs;
1138 :
1139 : // compute total return size
1140 : size_t retSize = lua_traits<TData>::size;
1141 :
1142 : // call lua function
1143 0 : if(lua_pcall(m_L, argSize, retSize, 0) != 0)
1144 0 : UG_THROW("LuaFunction::operator(...): Error while "
1145 : "running callback '" << m_cbValueName << "',"
1146 : " lua message: "<< lua_tostring(m_L, -1));
1147 :
1148 : try{
1149 : // read return value
1150 0 : lua_traits<TData>::read(m_L, out);
1151 0 : UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
1152 : }
1153 0 : UG_CATCH_THROW("LuaFunction::operator(...): Error while running "
1154 : "callback '" << m_cbValueName << "'");
1155 :
1156 : // pop values
1157 0 : lua_pop(m_L, retSize);
1158 :
1159 : PROFILE_CALLBACK_END();
1160 0 : }
1161 :
1162 :
1163 :
1164 : } // end namespace ug
1165 :
1166 : #endif /* LUA_USER_DATA_IMPL_H_ */
|