```cpp
// Python.h must be included before any standard header.
#include "python2.7/Python.h"

#include <cstdlib>
#include <iostream>
#include <string>

// Print the pending Python exception (if any) and an error message, then exit.
[[noreturn]] static void Fail(const std::string& msg) {
    if (PyErr_Occurred()) PyErr_Print();
    std::cerr << "[ERROR] " << msg << std::endl;
    std::exit(1);
}

// Instantiate a Python class and return a new reference. PyObject_CallObject
// works for both classic and new-style classes, whereas PyInstance_New only
// accepts classic (old-style) classes. `args` may be NULL for no arguments.
static PyObject* New_PyInstance(PyObject* cls, PyObject* args) {
    PyObject* instance = PyObject_CallObject(cls, args);
    if (!instance) Fail("new instance failed");
    return instance;
}

// Look up obj.name and return a new reference, or exit if it is missing.
static PyObject* GetAttr(PyObject* obj, const char* name) {
    PyObject* attr = PyObject_GetAttrString(obj, name);
    if (!attr) Fail(std::string("missing attribute: ") + name);
    return attr;
}

// Read obj.name as a C long, or exit if it is missing or not an integer.
static long GetIntAttr(PyObject* obj, const char* name) {
    PyObject* attr = GetAttr(obj, name);
    long value = PyInt_AsLong(attr);
    Py_DECREF(attr);
    if (value == -1 && PyErr_Occurred())
        Fail(std::string("attribute is not an int: ") + name);
    return value;
}

int main() {
    Py_Initialize();
    // Make modules in the working directory (tree.py) importable.
    PyRun_SimpleString("import sys");
    PyRun_SimpleString("sys.path.append('./')");

    PyObject* pModule = PyImport_ImportModule("tree");
    if (!pModule) Fail("failed to import Python module 'tree'");

    // Environment: tree.environment() takes no constructor arguments.
    PyObject* pEnv = GetAttr(pModule, "environment");
    PyObject* pEnvObject = New_PyInstance(pEnv, NULL);

    long level = GetIntAttr(pEnvObject, "level");
    long actions = GetIntAttr(pEnvObject, "actions");
    long states = GetIntAttr(pEnvObject, "states");
    long final_state = GetIntAttr(pEnvObject, "final_states");

    std::cout << "env level: " << level << std::endl;
    std::cout << "env actions: " << actions << std::endl;
    std::cout << "env states: " << states << std::endl;
    std::cout << "env final_state: " << final_state << std::endl;

    // Learner: tree.q_learning(states, actions).
    PyObject* pLearn = GetAttr(pModule, "q_learning");
    PyObject* pLearnArgs = Py_BuildValue("(ll)", states, actions);
    PyObject* pLearnObject = New_PyInstance(pLearn, pLearnArgs);
    Py_DECREF(pLearnArgs);

    long learn_states = GetIntAttr(pLearnObject, "states");
    long learn_actions = GetIntAttr(pLearnObject, "actions");
    PyObject* pLearnEps = GetAttr(pLearnObject, "eps");
    double learn_eps = PyFloat_AsDouble(pLearnEps);
    Py_DECREF(pLearnEps);
    if (learn_eps == -1.0 && PyErr_Occurred()) Fail("eps is not a number");

    std::cout << "learn_states: " << learn_states << std::endl;
    std::cout << "learn_actions: " << learn_actions << std::endl;
    std::cout << "learn_eps: " << learn_eps << std::endl;

    // Bound methods used by the training loop.
    PyObject* pEnvResetFunc = GetAttr(pEnvObject, "reset");
    PyObject* pEnvNextFunc = GetAttr(pEnvObject, "next");
    PyObject* pLearnGetActionFunc = GetAttr(pLearnObject, "get_action");
    PyObject* pLearnUpdateFunc = GetAttr(pLearnObject, "update");

    std::cout << std::endl;
    for (long episode = 0; episode < 10000; ++episode) {
        if (episode % 100 == 0) std::cout << "episode: " << episode << std::endl;

        // reset() returns the start state; this reference is owned here.
        PyObject* current_state = PyObject_CallObject(pEnvResetFunc, NULL);
        if (!current_state) Fail("env.reset() failed");

        while (true) {
            // PyTuple_SetItem steals a reference, so placing one object in
            // several tuples, or DECREF'ing borrowed tuple items, corrupts
            // reference counts. PyObject_CallFunctionObjArgs only borrows
            // its arguments, so the same objects can be passed repeatedly.
            PyObject* action = PyObject_CallFunctionObjArgs(
                pLearnGetActionFunc, current_state, NULL);
            if (!action) Fail("q_learning.get_action() failed");

            // next(state, action) must return (next_state, reward, final).
            PyObject* ret = PyObject_CallFunctionObjArgs(
                pEnvNextFunc, current_state, action, NULL);
            if (!ret || !PyTuple_Check(ret) || PyTuple_Size(ret) < 3)
                Fail("env.next() must return a (next_state, reward, final) tuple");
            PyObject* next_state = PyTuple_GetItem(ret, 0);  // borrowed
            PyObject* reward = PyTuple_GetItem(ret, 1);      // borrowed
            PyObject* is_final = PyTuple_GetItem(ret, 2);    // borrowed

            PyObject* res = PyObject_CallFunctionObjArgs(pLearnUpdateFunc,
                current_state, action, next_state, reward, is_final, NULL);
            if (!res) Fail("q_learning.update() failed");
            Py_DECREF(res);

            int done = PyObject_IsTrue(is_final);
            if (done < 0) Fail("cannot interpret 'final' as a bool");

            // Take our own reference to next_state before releasing ret.
            Py_INCREF(next_state);
            Py_DECREF(current_state);
            current_state = next_state;
            Py_DECREF(action);
            Py_DECREF(ret);
            if (done) break;
        }
        Py_DECREF(current_state);
    }

    // Print the learned Q-table, expected to be a list of per-state lists.
    PyObject* pLearnQTable = GetAttr(pLearnObject, "q_table");
    if (!PyList_Check(pLearnQTable)) Fail("q_table is not a list");
    for (Py_ssize_t i = 0; i < PyList_Size(pLearnQTable); ++i) {
        std::cout << "state " << i << std::endl;
        PyObject* term = PyList_GetItem(pLearnQTable, i);  // borrowed
        if (!PyList_Check(term)) continue;
        for (Py_ssize_t j = 0; j < PyList_Size(term); ++j) {
            std::cout << " direct: " << j << ", Qvalue: "
                      << PyFloat_AsDouble(PyList_GetItem(term, j)) << std::endl;
        }
    }

    // Release every reference we own before shutting the interpreter down.
    Py_DECREF(pLearnQTable);
    Py_DECREF(pLearnUpdateFunc);
    Py_DECREF(pLearnGetActionFunc);
    Py_DECREF(pEnvNextFunc);
    Py_DECREF(pEnvResetFunc);
    Py_DECREF(pLearnObject);
    Py_DECREF(pLearn);
    Py_DECREF(pEnvObject);
    Py_DECREF(pEnv);
    Py_DECREF(pModule);
    Py_Finalize();
    return 0;
}
```
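The `tree` module is not shown, so what `q_learning.get_action` and `q_learning.update` actually compute is unknown. Assuming they implement standard tabular Q-learning with epsilon-greedy exploration, the logic would look roughly like the following C++ sketch. Only `eps`, `q_table`, `get_action`, and `update` come from the driver; the class name `QLearning` and the `learning_rate` (alpha), `discount` (gamma), and `seed` parameters are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <random>
#include <vector>

// Hypothetical C++ counterpart of tree.q_learning, ASSUMING standard tabular
// Q-learning with epsilon-greedy exploration. learning_rate and discount are
// assumed parameters; they do not appear in the C++ driver.
class QLearning {
public:
    QLearning(int states, int actions, double epsilon = 0.1,
              double learning_rate = 0.1, double discount = 0.9,
              unsigned seed = 0)
        : q_table(states, std::vector<double>(actions, 0.0)),
          eps(epsilon), alpha(learning_rate), gamma(discount), rng(seed) {
        assert(states > 0 && actions > 0);
    }

    // With probability eps pick a uniformly random action, otherwise the
    // action with the highest Q-value (ties go to the lowest index).
    int get_action(int state) {
        const std::vector<double>& row = q_table[state];
        std::uniform_real_distribution<double> coin(0.0, 1.0);
        if (coin(rng) < eps) {
            std::uniform_int_distribution<int> pick(0, static_cast<int>(row.size()) - 1);
            return pick(rng);
        }
        return static_cast<int>(std::max_element(row.begin(), row.end()) - row.begin());
    }

    // Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)).
    // When the transition ends the episode there is no future value, so the
    // target is just the reward.
    void update(int state, int action, int next_state, double reward, bool done) {
        double target = reward;
        if (!done) {
            const std::vector<double>& next = q_table[next_state];
            target += gamma * *std::max_element(next.begin(), next.end());
        }
        q_table[state][action] += alpha * (target - q_table[state][action]);
    }

    std::vector<std::vector<double>> q_table;  // q_table[state][action]
    double eps;                                // exploration rate

private:
    double alpha;
    double gamma;
    std::mt19937 rng;
};
```

Setting `eps` to zero makes `get_action` purely greedy, which is convenient when inspecting a trained table but stops exploration if used during training.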
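The training loop relies on an implicit contract with the Python environment: `reset()` returns a start state, and `next(state, action)` returns a 3-tuple whose first item is the next state and whose third item signals the end of the episode. Reading the second item as the reward is an inference from its being passed to `update` alongside the transition. A minimal C++ sketch of that contract, using a hypothetical chain environment `ChainEnv` and a random policy in place of the learner (`run_episode` and `max_steps` are also hypothetical), might look like this:

```cpp
#include <cassert>
#include <random>
#include <tuple>

// Hypothetical stand-in for tree.environment: a chain of num_states cells.
// Action 1 moves right, any other action moves left (clamped at cell 0), and
// reaching the last cell ends the episode with reward 1. The real
// environment is not shown and will differ.
class ChainEnv {
public:
    explicit ChainEnv(int num_states) : states(num_states) {
        assert(num_states >= 2);
    }

    int reset() const { return 0; }

    // Same shape as env.next(state, action): (next_state, reward, final).
    std::tuple<int, double, bool> next(int state, int action) const {
        int next_state = (action == 1) ? state + 1 : state - 1;
        if (next_state < 0) next_state = 0;
        bool done = (next_state == states - 1);
        return std::make_tuple(next_state, done ? 1.0 : 0.0, done);
    }

    const int states;
};

// Mirrors the driver's inner loop: choose an action, step the environment,
// stop on a final state, otherwise continue from next_state. A uniformly
// random policy replaces get_action, and no learning update is performed.
// Returns the number of steps taken, or -1 if max_steps was exceeded.
int run_episode(const ChainEnv& env, std::mt19937& rng, int max_steps) {
    std::uniform_int_distribution<int> pick(0, 1);
    int current_state = env.reset();
    for (int step = 1; step <= max_steps; ++step) {
        int action = pick(rng);
        int next_state = 0;
        double reward = 0.0;
        bool done = false;
        std::tie(next_state, reward, done) = env.next(current_state, action);
        (void)reward;  // a learner would pass this to update()
        if (done) return step;
        current_state = next_state;
    }
    return -1;
}
```

The `max_steps` cap is a design choice the Python-driven loop lacks: its `while (true)` runs until `final` is truthy, so an environment that never reaches a final state would hang the driver.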